Example #1
def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-regressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
                    up_shifted_deconv2d, up_left_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                xs = int_shape(x)
                x_pad = tf.concat(
                    [x, tf.ones(xs[:-1] + [1])], 3
                )  # add channel of ones to distinguish image from padding later on

                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]  # stream for pixels above
                ul_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[1, 3])) +
                    left_shift(
                        up_left_shifted_conv2d(x_pad,
                                               num_filters=nr_filters,
                                               filter_size=[2, 1]))
                ]  # stream for up and to the left

                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))

                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                return x_out
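
A note on usage: reverse_pixel_cnn_28_binary asserts that bn is False and expects a spatial mask alongside the image it conditions on. The sketch below shows how it might be wired up; the shapes, placeholder names, and hyperparameters are assumptions for illustration, not taken from the original project.

# Hedged usage sketch (assumed shapes and names; requires the helpers above).
import tensorflow as tf

x_bar = tf.placeholder(tf.float32, shape=(16, 28, 28, 3))  # batch of images (assumed)
masks = tf.placeholder(tf.float32, shape=(16, 28, 28))     # 1 = observed pixel (assumed)

reverse_features = reverse_pixel_cnn_28_binary(
    x_bar,
    masks,
    nr_resnet=1,
    nr_filters=100,
    nonlinearity=tf.nn.elu,
    bn=False,  # required: the assert above forbids batch normalization
    is_training=False)
# reverse_features has nr_filters channels; in the __model methods below it is
# concatenated with decoder features before entering the forward PixelCNN.
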
def conditioning_network_28(x,
                            masks,
                            nr_filters,
                            is_training=True,
                            nonlinearity=None,
                            bn=True,
                            kernel_initializer=None,
                            kernel_regularizer=None,
                            counters={}):
    name = get_name("conditioning_network_28", counters)
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    xs = int_shape(x)
    x = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
    with tf.variable_scope(name):
        with arg_scope([conv2d, residual_block, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            outputs = conv2d(x, nr_filters, 4, 1, "SAME")
            for l in range(4):
                outputs = conv2d(outputs, nr_filters, 4, 1, "SAME")
            outputs = conv2d(outputs,
                             nr_filters,
                             1,
                             1,
                             "SAME",
                             nonlinearity=None,
                             bn=False)
            return outputs
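
Both of these helpers rely on broadcast_masks_tf to lift a (batch, H, W) mask to image shape before multiplying or concatenating. Its implementation is not shown here; the NumPy sketch below only illustrates the presumed broadcasting behaviour.

import numpy as np

def broadcast_masks_np(masks, num_channels):
    # Presumed behaviour of broadcast_masks_tf: repeat a (B, H, W) mask
    # along a new trailing channel axis, giving (B, H, W, num_channels).
    return np.repeat(masks[..., np.newaxis], num_channels, axis=-1)

m = np.random.randint(0, 2, size=(2, 28, 28)).astype(np.float32)
print(broadcast_masks_np(m, 3).shape)  # (2, 28, 28, 3)
print(broadcast_masks_np(m, 1).shape)  # (2, 28, 28, 1)
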
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    x = contexts * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope(
                [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                xs = int_shape(x)
                x_pad = tf.concat(
                    [x, tf.ones(xs[:-1] + [1])], 3
                )  # add channel of ones to distinguish image from padding later on

                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]  # stream for pixels above
                ul_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[1, 3])) +
                    left_shift(
                        up_left_shifted_conv2d(x_pad,
                                               num_filters=nr_filters,
                                               filter_size=[2, 1]))
                ]  # stream for up and to the left
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                print("    * receptive_field", receptive_field)
                return x_out
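
context_encoder reports the receptive field it reaches over the masked context: the initial shifted convolutions cover a (2, 3) window above and to the left of the current pixel, and each additional gated_resnet block extends it by (1, 2). The same bookkeeping in isolation:

def pixelcnn_receptive_field(nr_resnet):
    # Mirrors the receptive_field counter in context_encoder above:
    # start at (2, 3), grow by (1, 2) per gated_resnet block.
    rows, cols = 2, 3
    for _ in range(nr_resnet):
        rows, cols = rows + 1, cols + 2
    return rows, cols

print(pixelcnn_receptive_field(5))  # (7, 13) with the default nr_resnet=5
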
    def __model(self, network_type="large"):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            num_channels = 1
        else:
            num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        self.use_prior = tf.placeholder_with_default(False, shape=())
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large
                decoder = conv_decoder_32_large
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                       **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                inputs = self.x
                if self.input_masks is not None:
                    inputs = inputs * broadcast_masks_tf(self.input_masks,
                                                         num_channels=3)
                    inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                        1 -
                        broadcast_masks_tf(self.input_masks, num_channels=3))
                    inputs = tf.concat([
                        inputs,
                        broadcast_masks_tf(self.input_masks, num_channels=1)
                    ],
                                       axis=-1)

                self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
                sigma = tf.exp(self.z_log_sigma_sq / 2.)
                self.z = gaussian_sampler(self.z_mu, sigma)
                self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                             tf.ones_like(sigma))

                self.z_ph = tf.placeholder_with_default(
                    tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
                self.use_z_ph = tf.placeholder_with_default(False, shape=())

                use_prior = tf.cast(tf.cast(self.use_prior, tf.int32),
                                    tf.float32)
                use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32),
                                   tf.float32)
                z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                    1 - use_z_ph) + use_z_ph * self.z_ph

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                            cond_features,
                                                            bn=False)
                self.x_hat = mix_logistic_sampler(
                    self.mix_logistic_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
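
In the graph above, gaussian_sampler presumably draws a reparameterized sample z = mu + sigma * eps, and the use_prior / use_z_ph booleans are cast to 0/1 float gates that blend between the posterior sample, a prior sample, and an externally fed z. Below is a NumPy sketch of that selection logic; the function and argument names are illustrative, not from the original code.

import numpy as np

def select_z(z_mu, z_log_sigma_sq, z_prior, z_ph, use_prior, use_z_ph):
    # Reparameterized posterior sample, as gaussian_sampler presumably produces.
    sigma = np.exp(z_log_sigma_sq / 2.)
    z_post = z_mu + sigma * np.random.randn(*z_mu.shape)
    # Same float-gate arithmetic as in the graph; use_prior and use_z_ph are 0. or 1.
    z = use_prior * z_prior + (1. - use_prior) * z_post
    return (1. - use_z_ph) * z + use_z_ph * z_ph

mu = np.zeros((4, 8))
z = select_z(mu, np.zeros((4, 8)), np.random.randn(4, 8),
             np.full((4, 8), 7.), use_prior=0., use_z_ph=1.)
print(np.allclose(z, 7.))  # True: when use_z_ph is set, the fed-in z wins
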
    def __model(self,
                x,
                x_bar,
                is_training,
                dropout_p,
                masks,
                input_masks,
                network_size="medium"):
        print("******   Building Graph   ******")
        self.x = x
        self.x_bar = x_bar
        self.is_training = is_training
        self.dropout_p = dropout_p
        self.masks = masks
        self.input_masks = input_masks
        if int_shape(x)[1] == 64:
            conv_encoder = conv_encoder_64_medium
            conv_decoder = conv_decoder_64_medium
        elif int_shape(x)[1] == 32:
            if network_size == 'medium':
                conv_encoder = conv_encoder_32_medium
                conv_decoder = conv_decoder_32_medium
            elif network_size == 'large':
                conv_encoder = conv_encoder_32_large
                conv_decoder = conv_decoder_32_large
            elif network_size == 'large1':
                conv_encoder = conv_encoder_32_large1
                conv_decoder = conv_decoder_32_large1
            else:
                raise Exception("unknown network type")
        with arg_scope(
            [conv_encoder, conv_decoder, context_encoder, cond_pixel_cnn],
                nonlinearity=self.nonlinearity,
                bn=self.bn,
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                is_training=self.is_training,
                counters=self.counters):
            inputs = self.x
            if self.input_masks is not None:
                inputs = inputs * broadcast_masks_tf(self.input_masks,
                                                     num_channels=3)
                inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                    1 - broadcast_masks_tf(self.input_masks, num_channels=3))
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                   axis=-1)
            self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            if self.use_mode == 'train':
                self.z = gaussian_sampler(self.z_mu, sigma)
            elif self.use_mode == 'test':
                self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
            print("use mode:{0}".format(self.use_mode))
            self.decoded_features = conv_decoder(self.z, output_features=True)
            if self.masks is None:
                sh = self.decoded_features
            else:
                self.encoded_context = context_encoder(
                    self.x,
                    self.masks,
                    bn=False,
                    nr_resnet=self.nr_resnet,
                    nr_filters=self.nr_filters)
                sh = tf.concat([self.decoded_features, self.encoded_context],
                               axis=-1)
                if self.input_masks is not None:
                    sh = tf.concat([
                        broadcast_masks_tf(self.input_masks, num_channels=1),
                        sh
                    ],
                                   axis=-1)
            self.mix_logistic_params = cond_pixel_cnn(
                self.x_bar,
                sh=sh,
                bn=False,
                dropout_p=self.dropout_p,
                nr_resnet=self.nr_resnet,
                nr_filters=self.nr_filters,
                nr_logistic_mix=self.nr_logistic_mix)
            self.x_hat = mix_logistic_sampler(
                self.mix_logistic_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
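
Here the conditioning input sh for cond_pixel_cnn is assembled channel-wise from the decoder features, the context-encoder features and, when input masks are given, the broadcast mask. Below is a small NumPy shape check of that assembly; all sizes are assumptions chosen for illustration.

import numpy as np

B, H, W, F = 4, 32, 32, 100                            # assumed batch / spatial / filter sizes
decoded_features = np.zeros((B, H, W, F), np.float32)  # conv_decoder(..., output_features=True)
encoded_context = np.zeros((B, H, W, F), np.float32)   # context_encoder(...)
mask_channel = np.zeros((B, H, W, 1), np.float32)      # broadcast_masks_tf(input_masks, 1)

sh = np.concatenate([decoded_features, encoded_context], axis=-1)
sh = np.concatenate([mask_channel, sh], axis=-1)
print(sh.shape)  # (4, 32, 32, 201): the conditioning tensor fed to cond_pixel_cnn
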
    def __model(self):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
                encoder_q = conv_encoder_32_q
            else:
                raise Exception("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                encoder_q = conv_encoder_28_binary_q
            else:
                raise Exception("unknown network type")
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([encoder, decoder, encoder_q], **kwargs):
            inputs = self.x

            self.num_particles = 16

            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            inputs_pos = tf.concat(
                [self.x,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            inputs = tf.concat([inputs, inputs_pos], axis=0)

            z_mu, z_log_sigma_sq = encoder(inputs, self.z_dim)
            self.z_mu_pr = z_mu[:self.batch_size]
            self.z_mu = z_mu[self.batch_size:]
            self.z_log_sigma_sq_pr = z_log_sigma_sq[:self.batch_size]
            self.z_log_sigma_sq = z_log_sigma_sq[self.batch_size:]

            self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

            x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
            masks = tf.tile(self.masks, [self.num_particles, 1, 1])
            self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
            self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
            self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                          [self.num_particles, 1])
            self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                             [self.num_particles, 1])
            sigma = tf.exp(self.z_log_sigma_sq / 2.)

            self.params = get_trainable_variables(["inference"])

            dist = tf.distributions.Normal(loc=0., scale=1.)
            epsilon = dist.sample(sample_shape=[
                self.batch_size * self.num_particles, self.z_dim
            ],
                                  seed=None)
            z = self.z_mu + tf.multiply(epsilon, sigma)

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
            if self.network_type == 'binary':
                nll = bernoulli_loss(x,
                                     self.pixel_params,
                                     masks=masks,
                                     output_mean=False)
            else:
                nll = mix_logistic_loss(x,
                                        self.pixel_params,
                                        masks=masks,
                                        output_mean=False)

            log_prob_pos = dist.log_prob(epsilon)
            epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                self.z_log_sigma_sq_pr / 2.)
            log_prob_pr = dist.log_prob(epsilon_pr)
            # convert back
            log_prob_pr = tf.stack([
                log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                   axis=0)
            log_prob_pos = tf.stack([
                log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                    axis=0)
            log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
            log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
            nll = tf.stack([
                nll[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                           axis=0)
            log_likelihood = -nll

            # log_weights = log_prob_pr + log_likelihood - log_prob_pos
            log_weights = log_likelihood
            log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
            log_avg_weight = log_sum_weight - tf.log(
                tf.to_float(self.num_particles))
            self.log_avg_weight = log_avg_weight

            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=0))
            sq_normalized_weights = tf.square(normalized_weights)

            self.gradients = tf.gradients(
                -tf.reduce_sum(sq_normalized_weights * log_weights, axis=0),
                self.params,
                colocate_gradients_with_ops=True)
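
The particle bookkeeping above reshapes per-example log-probabilities into a (num_particles, batch_size) layout and then averages the importance weights in log space as log_avg_weight = logsumexp(log_weights) - log(num_particles). Note that the commented-out line would use the full weights log p(x|z) + log p(z) - log q(z|x), while the code as written keeps only the likelihood term. A NumPy sketch of the estimator itself:

import numpy as np

def log_avg_weight_np(log_weights):
    # logsumexp over the particle axis minus log(num_particles),
    # matching the tf.reduce_logsumexp / tf.log combination above.
    num_particles = log_weights.shape[0]
    m = log_weights.max(axis=0)
    return m + np.log(np.exp(log_weights - m).sum(axis=0)) - np.log(num_particles)

lw = np.random.randn(16, 8)          # (num_particles, batch_size); sizes are assumptions
print(log_avg_weight_np(lw).shape)   # (8,): one per-example bound estimate
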
    def __model(self):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
            else:
                raise Exception("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
            else:
                raise Exception("unknown network type")
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([encoder, decoder], **kwargs):
            inputs = self.x
            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)

            self.z_ph = tf.placeholder_with_default(tf.zeros_like(self.z_mu),
                                                    shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
                self.x_hat = bernoulli_sampler(self.pixel_params,
                                               counters=self.counters)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
Example #8
    def __model(self):
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large
                encoder_q = conv_encoder_32_q
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                forward_pixelcnn = forward_pixel_cnn_28_binary
                reverse_pixelcnn = reverse_pixel_cnn_28_binary
                encoder_q = conv_encoder_28_binary_q
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope(
            [forward_pixelcnn, reverse_pixelcnn, encoder, decoder, encoder_q],
                **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                self.num_particles = 16

                inp = self.x * broadcast_masks_tf(
                    self.input_masks, num_channels=self.num_channels)
                inp += tf.random_uniform(
                    int_shape(inp), -1, 1) * (1 - broadcast_masks_tf(
                        self.input_masks, num_channels=self.num_channels))
                inp = tf.concat([
                    inp,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                axis=-1)

                inputs_pos = tf.concat([
                    self.x,
                    broadcast_masks_tf(tf.ones_like(self.input_masks),
                                       num_channels=1)
                ],
                                       axis=-1)
                inp = tf.concat([inp, inputs_pos], axis=0)

                z_mu, z_log_sigma_sq = encoder(inp, self.z_dim)
                self.z_mu_pr = z_mu[:self.batch_size]
                self.z_mu = z_mu[self.batch_size:]
                self.z_log_sigma_sq_pr = z_log_sigma_sq[:self.batch_size]
                self.z_log_sigma_sq = z_log_sigma_sq[self.batch_size:]

                x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
                x_bar = tf.tile(self.x_bar, [self.num_particles, 1, 1, 1])
                input_masks = tf.tile(self.input_masks,
                                      [self.num_particles, 1, 1])
                masks = tf.tile(self.masks, [self.num_particles, 1, 1])

                self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
                self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                                 [self.num_particles, 1])
                self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
                self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                              [self.num_particles, 1])

                self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

                sigma = tf.exp(self.z_log_sigma_sq / 2.)

                self.params = get_trainable_variables(["inference"])

                dist = tf.distributions.Normal(loc=0., scale=1.)
                epsilon = dist.sample(sample_shape=[
                    self.batch_size * self.num_particles, self.z_dim
                ],
                                      seed=None)
                z = self.z_mu + tf.multiply(epsilon, sigma)

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(x, masks, context=None, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                cond_features = tf.concat([
                    broadcast_masks_tf(input_masks, num_channels=1),
                    cond_features
                ],
                                          axis=-1)

                self.pixel_params = forward_pixelcnn(x_bar,
                                                     cond_features,
                                                     bn=False)

                if self.network_type == 'binary':
                    nll = bernoulli_loss(x,
                                         self.pixel_params,
                                         masks=masks,
                                         output_mean=False)
                else:
                    nll = mix_logistic_loss(x,
                                            self.pixel_params,
                                            masks=masks,
                                            output_mean=False)

                log_prob_pos = dist.log_prob(epsilon)
                epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                    self.z_log_sigma_sq_pr / 2.)
                log_prob_pr = dist.log_prob(epsilon_pr)
                # convert back
                log_prob_pr = tf.stack([
                    log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                                       axis=0)
                log_prob_pos = tf.stack([
                    log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                                        axis=0)
                log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
                log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
                nll = tf.stack([
                    nll[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                               axis=0)
                log_likelihood = -nll

                # log_weights = log_prob_pr + log_likelihood - log_prob_pos
                log_weights = log_likelihood
                log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
                log_avg_weight = log_sum_weight - tf.log(
                    tf.to_float(self.num_particles))
                self.log_avg_weight = log_avg_weight

                normalized_weights = tf.stop_gradient(
                    tf.nn.softmax(log_weights, axis=0))
                sq_normalized_weights = tf.square(normalized_weights)
                self.gradients = tf.gradients(-tf.reduce_sum(
                    sq_normalized_weights * log_weights, axis=0),
                                              self.params,
                                              colocate_gradients_with_ops=True)
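
The surrogate objective whose gradient is requested at the end, -sum_k w_k**2 * log_weights_k with w_k = softmax(log_weights) held constant via tf.stop_gradient, can be reproduced in isolation. The sketch below covers only the arithmetic, not the TensorFlow gradient plumbing.

import numpy as np

def surrogate_objective(log_weights):
    # Softmax over the particle axis (stop-gradient in the TF graph),
    # squared, multiplied by the raw log-weights, summed and negated.
    m = log_weights.max(axis=0, keepdims=True)
    w = np.exp(log_weights - m)
    w = w / w.sum(axis=0, keepdims=True)
    return -(np.square(w) * log_weights).sum(axis=0)

lw = np.random.randn(16, 8)           # (num_particles, batch_size); assumed sizes
print(surrogate_objective(lw).shape)  # (8,): per-example surrogate loss
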