def build(self, input_shape):
        """Create the layer stack, then delegate to the parent ``build``.

        ``self.layers`` is set to three ``SignalConv2D`` layers (``corr=False``,
        reflect padding, bias enabled), with an inverse ``GDN`` placed after
        each of the first two. The first two convolutions use
        ``self.num_filters`` filters, a 5x5 kernel and ``strides_up=2``; the
        last one uses 3 filters, a 9x9 kernel and ``strides_up=4``.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        # Settings shared by every convolution in the stack.
        conv_settings = dict(corr=False, padding="reflect", use_bias=True)

        stack = []
        for index in range(2):
            stack.append(
                SignalConv2D(filters=self.num_filters,
                             kernel=(5, 5),
                             name=f"conv_{index}",
                             strides_up=2,
                             **conv_settings))
            stack.append(GDN(name=f"igdn_{index}", inverse=True))

        # Final convolution: 3 output filters, no GDN afterwards.
        stack.append(
            SignalConv2D(filters=3,
                         kernel=(9, 9),
                         name="conv_2",
                         strides_up=4,
                         **conv_settings))

        self.layers = stack

        super().build(input_shape)
    def __init__(self, num_filters, name="large_1_level_vae", **kwargs):
        """Create the prior parameters and the analysis/synthesis transforms.

        Args:
            num_filters: Filter count used for the last dimension of the prior
                base variable, for the three prior convolutions and for both
                transforms.
            name: Name forwarded to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        super().__init__(name=name, **kwargs)

        self.num_filters = num_filters

        def prior_conv(layer_name):
            # The three prior convolutions are identical apart from their name.
            return SignalConv2D(filters=self.num_filters,
                                kernel=(3, 3),
                                corr=True,
                                strides_down=1,
                                padding="reflect",
                                use_bias=True,
                                name=layer_name)

        # Zero-initialised variable of shape (1, 1, 1, num_filters).
        base_shape = (1, 1, 1, self.num_filters)
        self._prior_base = tf.Variable(tf.zeros(base_shape), name="prior_base")

        self._prior_conv = prior_conv("prior_conv")
        self._prior_loc_head = prior_conv("prior_loc_head")
        self._prior_log_scale_head = prior_conv("prior_log_scale_head")

        self.analysis_transform = AnalysisTransform(
            num_filters=self.num_filters)
        self.synthesis_transform = SynthesisTransform(
            num_filters=self.num_filters)
    def build(self, input_shape):
        """Create the prior variable and its convolutions, then build the parent.

        Sets ``_prior_base`` (a zero-initialised variable of shape
        ``(1, 1, 1, num_filters)``) and three 3x3 ``SignalConv2D`` layers named
        ``conv_0``, ``loc_head`` and ``log_scale_head``.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        self._prior_base = tf.Variable(
            tf.zeros((1, 1, 1, self.num_filters)), name="prior_base")

        # Every convolution here shares these settings; only the name differs.
        conv_settings = dict(filters=self.num_filters,
                             kernel=(3, 3),
                             corr=True,
                             strides_down=1,
                             padding="reflect",
                             use_bias=True)

        self._prior_conv = SignalConv2D(name="conv_0", **conv_settings)
        self._prior_loc_head = SignalConv2D(name="loc_head", **conv_settings)
        self._prior_log_scale_head = SignalConv2D(name="log_scale_head",
                                                  **conv_settings)

        super().build(input_shape)
    def build(self, input_shape):
        """Create the ELU convolution stack and the three output heads.

        ``self.layers`` holds two 5x5 ``SignalConv2D`` layers (``corr=False``,
        ``strides_up=2``, ELU activation) with ``self.num_filters`` filters.
        The prior location head, prior log-scale head and deterministic
        features head are 3x3 ``SignalConv2D`` layers with ``strides_up=1``
        and ``self.num_output_filters`` filters.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        def output_head(layer_name):
            # The heads are identical apart from their name.
            return SignalConv2D(filters=self.num_output_filters,
                                kernel=(3, 3),
                                name=layer_name,
                                corr=False,
                                strides_up=1,
                                padding="reflect",
                                use_bias=True)

        self.layers = [
            SignalConv2D(filters=self.num_filters,
                         kernel=(5, 5),
                         name=f"conv_{index}",
                         corr=False,
                         strides_up=2,
                         padding="reflect",
                         use_bias=True,
                         activation=tf.nn.elu)
            for index in range(2)
        ]

        self._prior_loc_head = output_head("prior_loc_head")
        self._prior_log_scale_head = output_head("prior_log_scale_head")
        self._deterministic_features_head = output_head(
            "deterministic_features_head")

        super().build(input_shape)
    def build(self, input_shape):
        """Create the GDN-activated convolution stack and the three heads.

        ``self.layers`` holds two 3x3 ``SignalConv2D`` layers (``corr=True``,
        ``strides_down=1``), each given a non-inverse ``GDN`` as its
        activation. The posterior location head, posterior log-scale head and
        deterministic features head use the same convolution settings without
        an activation. All layers use ``self.num_filters`` filters.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        # Settings common to every convolution built in this method.
        conv_settings = dict(filters=self.num_filters,
                             kernel=(3, 3),
                             corr=True,
                             strides_down=1,
                             padding="reflect",
                             use_bias=True)

        self.layers = [
            SignalConv2D(name=f"conv_{index}",
                         activation=GDN(inverse=False, name=f"gdn_{index}"),
                         **conv_settings)
            for index in range(2)
        ]

        self._posterior_loc_head = SignalConv2D(name="posterior_loc_head",
                                                **conv_settings)

        self._posterior_log_scale_head = SignalConv2D(
            name="posterior_log_scale_head", **conv_settings)

        # NOTE(review): "feautres" looks like a typo of "features". The string
        # is kept as-is because it is the layer's name; confirm nothing depends
        # on it before correcting.
        self._deterministic_features_head = SignalConv2D(
            name="deterministic_feautres_head", **conv_settings)

        super().build(input_shape)
    def build(self, input_shape):
        """Create the strided convolution stack and the two posterior heads.

        ``self.layers`` alternates ``SignalConv2D`` and non-inverse ``GDN``:
        a 9x9 convolution with ``strides_down=4``, then a 5x5 convolution with
        ``strides_down=2``. Both posterior heads are 5x5 convolutions with
        ``strides_down=2`` and no bias. All convolutions use
        ``self.num_filters`` filters, ``corr=True`` and reflect padding.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        def strided_conv(layer_name, kernel, stride, use_bias):
            # Shared constructor for every convolution in this method.
            return SignalConv2D(filters=self.num_filters,
                                kernel=kernel,
                                name=layer_name,
                                corr=True,
                                strides_down=stride,
                                padding="reflect",
                                use_bias=use_bias)

        self.layers = [
            strided_conv("conv_0", (9, 9), 4, True),
            GDN(inverse=False, name="gdn_0"),
            strided_conv("conv_1", (5, 5), 2, True),
            GDN(inverse=False, name="gdn_1"),
        ]

        # Unlike the stack above, the heads are created with use_bias=False.
        self._posterior_loc_head = strided_conv(
            "posterior_loc_head", (5, 5), 2, False)
        self._posterior_log_scale_head = strided_conv(
            "posterior_log_scale_head", (5, 5), 2, False)

        super().build(input_shape)
    def build(self, input_shape):
        """Create the ELU convolution stack and the two posterior heads.

        ``self.layers`` holds two 3x3 ``SignalConv2D`` layers (``corr=True``,
        ``strides_down=1``) with ELU activation. The posterior location and
        log-scale heads use the same settings without an activation. All
        layers use ``self.num_filters`` filters.

        Args:
            input_shape: Forwarded unchanged to ``super().build``.
        """
        def same_size_conv(layer_name, **extra):
            # ``extra`` carries the optional activation for the stack layers.
            return SignalConv2D(filters=self.num_filters,
                                kernel=(3, 3),
                                corr=True,
                                strides_down=1,
                                padding="reflect",
                                use_bias=True,
                                name=layer_name,
                                **extra)

        self.layers = [
            same_size_conv(f"conv_{index}", activation=tf.nn.elu)
            for index in range(2)
        ]

        self._posterior_loc_head = same_size_conv("posterior_loc_head")
        self._posterior_log_scale_head = same_size_conv(
            "posterior_log_scale_head")

        super().build(input_shape)
    def __init__(self,
                 level_1_filters=192,
                 level_2_filters=192,
                 level_3_filters=128,
                 level_4_filters=128,
                 name="large_level_4_vae",
                 **kwargs):
        """Create the transforms, connectors and combiners of the 4-level model.

        Args:
            level_1_filters: Filter count for the level-1 transforms and for
                every connector/combiner that outputs at level 1.
            level_2_filters: Same, for level 2.
            level_3_filters: Same, for level 3.
            level_4_filters: Same, for level 4; also used for the empirical
                hyper prior.
            name: Name forwarded to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        super().__init__(name=name, **kwargs)

        self.level_1_filters = level_1_filters
        self.level_2_filters = level_2_filters
        self.level_3_filters = level_3_filters
        self.level_4_filters = level_4_filters

        # Filter counts ordered by level, used to build the combiner banks.
        filters_per_level = [
            level_1_filters, level_2_filters, level_3_filters, level_4_filters
        ]

        def pointwise(filters, layer_name):
            # 1x1 Conv2D with stride 1, valid padding and a bias.
            return tfl.Conv2D(filters=filters,
                              kernel_size=(1, 1),
                              strides=1,
                              padding="valid",
                              use_bias=True,
                              name=layer_name)

        def downsampler(filters, layer_name, kernel=(5, 5), stride=4):
            # SignalConv2D with corr=True that strides down.
            return SignalConv2D(filters=filters,
                                kernel=kernel,
                                strides_down=stride,
                                corr=True,
                                padding="reflect",
                                use_bias=True,
                                name=layer_name)

        def upsampler(filters, layer_name):
            # 5x5 SignalConv2D with corr=False that strides up by 4.
            return SignalConv2D(filters=filters,
                                kernel=(5, 5),
                                strides_up=4,
                                corr=False,
                                padding="reflect",
                                use_bias=True,
                                name=layer_name)

        def combiner_bank(suffix, filter_counts):
            # One 1x1 convolution per level, named "level_<n>_<suffix>".
            return [
                pointwise(filters, f"level_{level}_{suffix}")
                for level, filters in enumerate(filter_counts, start=1)
            ]

        # --------------------------------------------------------------
        # Main components of the model
        # --------------------------------------------------------------
        self.analysis_transform = AnalysisTransform(
            num_filters=level_1_filters)
        self.synthesis_transform = SynthesisTransform(
            num_filters=level_1_filters)

        self.extended_analysis_transform = ExtendedAnalysisTransform(
            num_filters=level_2_filters)
        self.extended_synthesis_transform = ExtendedSynthesisTransform(
            num_filters=level_2_filters, num_output_filters=level_1_filters)

        self.hyper_analysis_transform = HyperAnalysisTransform(
            num_filters=level_3_filters)
        self.hyper_synthesis_transform = HyperSynthesisTransform(
            num_filters=level_3_filters, num_output_filters=level_2_filters)

        self.extended_hyper_analysis_transform = ExtendedHyperAnalysisTransform(
            num_filters=level_4_filters)
        self.extended_hyper_synthesis_transform = ExtendedHyperSynthesisTransform(
            num_filters=level_4_filters, num_output_filters=level_3_filters)

        self.empirical_hyper_prior = EmpiricalHyperPrior(
            num_filters=level_4_filters)

        # --------------------------------------------------------------
        # Inference-time residual connectors and combiners
        # --------------------------------------------------------------
        self.inputs_to_level_1_connector = downsampler(
            level_1_filters,
            "inputs_to_level_1_connector",
            kernel=(9, 9),
            stride=8)

        # We will compose with the inputs to level 1 connector
        self.inputs_to_level_2_connector = pointwise(
            level_2_filters, "inputs_to_level_2_connector")

        self.level_1_to_level_2_connector = pointwise(
            level_2_filters, "level_1_to_level_2_connector")

        # We will compose with the inputs to level 1 connector
        self.inputs_to_level_3_connector = downsampler(
            level_3_filters, "inputs_to_level_3_connector")

        self.level_1_to_level_3_connector = downsampler(
            level_3_filters, "level_1_to_level_3_connector")

        self.level_2_to_level_3_connector = downsampler(
            level_3_filters, "level_2_to_level_3_connector")

        # Only levels 1-3 get an inference combiner.
        self.inference_combiners = combiner_bank("inference_combiner",
                                                 filters_per_level[:3])

        # --------------------------------------------------------------
        # Generative-pass connectors and combiners
        # --------------------------------------------------------------
        self.level_4_to_level_3_connector = pointwise(
            level_3_filters, "level_4_to_level_3_connector")

        self.level_4_to_level_2_connector = upsampler(
            level_2_filters, "level_4_to_level_2_connector")

        self.level_4_to_level_1_connector = upsampler(
            level_1_filters, "level_4_to_level_1_connector")

        self.level_3_to_level_2_connector = upsampler(
            level_2_filters, "level_3_to_level_2_connector")

        self.level_3_to_level_1_connector = upsampler(
            level_1_filters, "level_3_to_level_1_connector")

        self.level_2_to_level_1_connector = pointwise(
            level_1_filters, "level_2_to_level_1_connector")

        self.generative_combiners = combiner_bank("generative_combiner",
                                                  filters_per_level)

        # --------------------------------------------------------------
        # Posterior statistic combiners
        # --------------------------------------------------------------
        self.posterior_loc_combiners = combiner_bank("posterior_loc_combiner",
                                                     filters_per_level)

        self.posterior_log_scale_combiners = combiner_bank(
            "posterior_log_scale_combiner", filters_per_level)
    def build(self, input_shape):
        """Create the inference-side, generative-side and optional IAF layers.

        Every named convolution is a ``SignalConv2D`` when
        ``self.use_sig_convs`` is truthy and a ``ReparameterizedConv2D`` with
        ``"same"`` padding otherwise. The two inference convolutions are only
        created when ``self.is_last`` is falsy, and the IAF layers only when
        ``self.use_iaf`` is truthy.

        Args:
            input_shape: Forwarded to ``super().build`` as ``input_shape``.
        """
        def make_conv(layer_name, filters, *, inference, use_bias=True):
            # Single factory for the three layer flavours used below.
            # ``use_bias`` only reaches the SignalConv2D variants.
            if not self.use_sig_convs:
                return ReparameterizedConv2D(filters=filters,
                                             kernel_size=self.kernel_size,
                                             strides=(1, 1),
                                             padding="same",
                                             name=layer_name)

            if inference:
                return SignalConv2D(filters=filters,
                                    kernel=self.kernel_size,
                                    strides_down=1,
                                    corr=True,
                                    padding="reflect",
                                    use_bias=use_bias,
                                    name=layer_name)

            return SignalConv2D(filters=filters,
                                kernel=self.kernel_size,
                                strides_up=1,
                                corr=False,
                                padding="reflect",
                                use_bias=use_bias,
                                name=layer_name)

        # ---------------------------------------------------------------------
        # Inference side
        # ---------------------------------------------------------------------
        if not self.is_last:
            self.infer_conv1 = make_conv("infer_conv_0",
                                         self.deterministic_filters,
                                         inference=True)

            self.infer_conv2 = make_conv("infer_conv_1",
                                         self.deterministic_filters,
                                         inference=True)

        # The SignalConv2D posterior heads carry a bias only when this is not
        # the last block.
        self.infer_posterior_loc_head = make_conv(
            "infer_posterior_loc_head",
            self.stochastic_filters,
            inference=True,
            use_bias=not self.is_last)

        self.infer_posterior_log_scale_head = make_conv(
            "infer_posterior_log_scale_head",
            self.stochastic_filters,
            inference=True,
            use_bias=not self.is_last)

        # ---------------------------------------------------------------------
        # Generative side
        # Note: In the general case, these should technically be deconvolutions, but
        # in the original implementation the dimensions within a single block do not
        # decrease, hence there is not much point in using the more expensive operation
        # ---------------------------------------------------------------------
        self.gen_conv1 = make_conv("gen_conv_0",
                                   self.deterministic_filters,
                                   inference=False)

        self.gen_conv2 = make_conv("gen_conv_1",
                                   self.deterministic_filters,
                                   inference=False)

        self.prior_loc_head = make_conv("prior_loc_head",
                                        self.stochastic_filters,
                                        inference=False)

        self.prior_log_scale_head = make_conv("prior_log_scale_head",
                                              self.stochastic_filters,
                                              inference=False)

        self.gen_posterior_loc_head = make_conv("gen_posterior_loc_head",
                                                self.stochastic_filters,
                                                inference=False)

        self.gen_posterior_log_scale_head = make_conv(
            "gen_posterior_log_scale_head",
            self.stochastic_filters,
            inference=False)

        # ---------------------------------------------------------------------
        # IAF posteriors need some additional layers
        # ---------------------------------------------------------------------
        if self.use_iaf:
            # These two context convolutions are created without an explicit
            # name, exactly as before.
            self.infer_iaf_autoregressive_context_conv = ReparameterizedConv2D(
                filters=self.deterministic_filters,
                kernel_size=self.kernel_size,
                strides=(1, 1),
                padding="same")

            self.gen_iaf_autoregressive_context_conv = ReparameterizedConv2D(
                filters=self.deterministic_filters,
                kernel_size=self.kernel_size,
                strides=(1, 1),
                padding="same")

            multiconv_filters = [
                self.deterministic_filters, self.deterministic_filters
            ]
            multiconv_head_filters = [
                self.stochastic_filters, self.stochastic_filters
            ]
            self.iaf_posterior_multiconv = AutoRegressiveMultiConv2D(
                convolution_filters=multiconv_filters,
                head_filters=multiconv_head_filters)

        super().build(input_shape=input_shape)
示例#10
0
    def __init__(self,
                 sampler,
                 sampler_args={},
                 coder_args={},
                 use_gdn=True,
                 use_sig_convs=True,
                 distribution="gaussian",
                 likelihood_function="laplace",
                 learn_likelihood_scale=False,
                 first_kernel_size=(5, 5),
                 first_strides=(2, 2),
                 kernel_size=(5, 5),
                 strides=(2, 2),
                 first_deterministic_filters=160,
                 second_deterministic_filters=160,
                 first_stochastic_filters=128,
                 second_stochastic_filters=32,
                 kl_per_partition=10,
                 latent_size="variable",
                 ema_decay=0.999,
                 name="resnet_vae",
                 **kwargs):
        super().__init__(name=name,
                         **kwargs)

        # ---------------------------------------------------------------------
        # Assign hyperparamteres
        # ---------------------------------------------------------------------

        self.distribution = distribution

        self.sampler_name = str(sampler)
        self.learn_likelihood_scale = learn_likelihood_scale

        if likelihood_function not in self.AVAILABLE_LIKELIHOODS:
            raise ModelError(f"Likelihood function must be one of: {self.AVAILABLE_LIKELIHOODS}! "
                             f"({likelihood_function} was given).")

        self._likelihood_function = likelihood_function

        self.first_kernel_size = first_kernel_size
        self.first_strides = first_strides

        self.kernel_size = kernel_size
        self.strides = strides
        self.first_stochastic_filters = first_stochastic_filters
        self.first_deterministic_filters = first_deterministic_filters
        self.second_stochastic_filters = second_stochastic_filters
        self.second_deterministic_filters = second_deterministic_filters

        self.kl_per_partition = kl_per_partition
        # Decay for exponential moving average update to variables
        self.ema_decay = tf.cast(ema_decay, tf.float32)

        # ---------------------------------------------------------------------
        # Create parameters
        # ---------------------------------------------------------------------
        self.likelihood_log_scale = tf.Variable(0.,
                                                name="likelihood_log_scale",
                                                trainable=self.learn_likelihood_scale)

        # ---------------------------------------------------------------------
        # Create ResNet Layers
        # ---------------------------------------------------------------------
        # The first deterministic inference block downsamples 8x8
        # Note: we don't apply an ELU at the end of the block, this will happen
        # in the residual block
        self.first_infer_block = [
            (SignalConv2D(kernel=(5, 5),
                          corr=True,
                          strides_down=2,
                          filters=self.first_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_0_0")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.first_deterministic_filters,
                                   padding="same")
             ),

            (GDN(inverse=False, name="inf_gdn_0") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          corr=True,
                          strides_down=2,
                          filters=self.first_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_0_1")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.first_deterministic_filters,
                                   padding="same")
             ),

            (GDN(inverse=False, name="inf_gdn_1") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          corr=True,
                          strides_down=2,
                          filters=self.first_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_0_2")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.first_deterministic_filters,
                                   padding="same")
             ),

            (GDN(inverse=False, name="inf_gdn_2") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          corr=True,
                          strides_down=2,
                          filters=self.first_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_0_3")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.first_deterministic_filters,
                                   padding="same")
             ),

        ]

        # The first deterministic generative block is the pseudoinverse of the inference block
        self.first_gen_block = [
            tf.nn.elu,
            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          filters=self.first_deterministic_filters,
                          corr=False,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_0_0")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=self.first_deterministic_filters,
                                            padding="same")),

            (GDN(inverse=True, name="gen_gdn_0") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          filters=self.first_deterministic_filters,
                          corr=False,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_0_1")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=self.first_deterministic_filters,
                                            padding="same")),

            (GDN(inverse=True, name="gen_gdn_1") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          filters=self.first_deterministic_filters,
                          corr=False,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_0_2")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=self.first_deterministic_filters,
                                            padding="same")),

            (GDN(inverse=True, name="gen_gdn_2") if use_gdn else tf.nn.elu),

            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          filters=3,
                          corr=False,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_0_3")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=3,
                                            padding="same")),
        ]

        # The second deterministic inference block downsamples by another 4x4
        self.second_infer_block = [
            (SignalConv2D(kernel=(3, 3),
                          strides_down=1,
                          corr=True,
                          filters=self.second_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_1_0")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(3, 3),
                                   strides=1,
                                   filters=self.second_deterministic_filters,
                                   padding="same")),
            tf.nn.elu,
            (SignalConv2D(kernel=(5, 5),
                          strides_down=2,
                          corr=True,
                          filters=self.second_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_1_1")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.second_deterministic_filters,
                                   padding="same")),
            tf.nn.elu,
            (SignalConv2D(kernel=(5, 5),
                          strides_down=2,
                          corr=True,
                          filters=self.second_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="infer_sig_conv_1_2")
             if use_sig_convs else
             ReparameterizedConv2D(kernel_size=(5, 5),
                                   strides=2,
                                   filters=self.second_deterministic_filters,
                                   padding="same")),
        ]

        self.second_gen_block = [
            tf.nn.elu,
            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          corr=False,
                          filters=self.second_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_1_0")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=self.second_deterministic_filters,
                                            padding="same")),
            tf.nn.elu,
            (SignalConv2D(kernel=(5, 5),
                          strides_up=2,
                          corr=False,
                          filters=self.second_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_1_1")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(5, 5),
                                            strides=2,
                                            filters=self.second_deterministic_filters,
                                            padding="same")),
            tf.nn.elu,
            (SignalConv2D(kernel=(3, 3),
                          strides_up=1,
                          corr=False,
                          filters=self.first_deterministic_filters,
                          padding="reflect",
                          use_bias=True,
                          name="gen_sig_conv_1_2")
             if use_sig_convs else
             ReparameterizedConv2DTranspose(kernel_size=(3, 3),
                                            strides=1,
                                            filters=self.first_deterministic_filters,
                                            padding="same")),
        ]

        # Create Stochastic Residual Blocks
        self.first_residual_block = BidirectionalResidualBlock(
            stochastic_filters=self.first_stochastic_filters,
            deterministic_filters=self.first_deterministic_filters,
            sampler=self.sampler_name,
            sampler_args=sampler_args,
            coder_args=coder_args,
            distribution=distribution,
            kernel_size=self.kernel_size,
            is_last=False,
            use_iaf=False,
            kl_per_partition=self.kl_per_partition,
            use_sig_convs=use_sig_convs,
            name=f"resnet_block_1"
        )

        self.second_residual_block = BidirectionalResidualBlock(
            stochastic_filters=self.second_stochastic_filters,
            deterministic_filters=self.second_deterministic_filters,
            sampler=self.sampler_name,
            sampler_args=sampler_args,
            coder_args=coder_args,
            distribution=distribution,
            kernel_size=self.kernel_size,
            is_last=True,
            use_iaf=False,
            kl_per_partition=self.kl_per_partition,
            use_sig_convs=use_sig_convs,
            name=f"resnet_block_2"
        )

        self.residual_blocks = [self.first_residual_block,
                                self.second_residual_block]

        # Likelihood distribution
        self.likelihood_dist = None

        # Likelihood of the most recent sample
        self.log_likelihood = -np.inf

        # this variable will allow us to perform Empirical Bayes on the first prior
        # Referred to as "h_top" in both the Kingma and Townsend implementations
        self._generative_base = tf.Variable(tf.zeros(self.second_deterministic_filters),
                                            name="generative_base")

        # ---------------------------------------------------------------------
        # EMA shadow variables
        # ---------------------------------------------------------------------
        self._ema_shadow_variables = {}

        self._compressor_initialized = tf.Variable(False, name="compressor_initialized", trainable=False)