def __loss(self):
     """Build the training objective: mean NLL plus the selected regulariser.

     Reads ``self.x``, ``self.pixel_params``, ``self.masks``,
     ``self.network_type``, ``self.num_channels`` and ``self.reg_type``;
     the regulariser branches additionally read ``self.beta`` and, depending
     on the branch, ``self.z``, ``self.z_dim``, ``self.z_mu`` and
     ``self.z_log_sigma_sq``.

     Sets ``self.loss_nll`` (mean NLL), ``self.bits_per_dim``, ``self.lam``,
     ``self.loss_reg`` and ``self.loss``.  The regulariser branches also set
     ``self.kld``, ``self.mmd`` or ``self.mi`` / ``self.tc`` / ``self.dwkld``.

     Raises:
         ValueError: if ``self.reg_type`` is not one of ``None``, ``'kld'``,
             ``'mmd'`` or ``'tc'``.
     """
     print("******   Compute Loss   ******")
     # 'binary' networks use the Bernoulli loss; every other network type
     # uses the logistic-mixture loss.
     if self.network_type == 'binary':
         self.loss_nll = bernoulli_loss(self.x,
                                        self.pixel_params,
                                        masks=self.masks,
                                        output_mean=False)
     else:
         self.loss_nll = mix_logistic_loss(self.x,
                                           self.pixel_params,
                                           masks=self.masks,
                                           output_mean=False)
     # dim = sum of (1 - mask) over axes 1 and 2, times the channel count.
     self.bits_per_dim = tf.reduce_mean(
         bits_per_dim_tf(nll=self.loss_nll,
                         dim=tf.reduce_sum(1 - self.masks, axis=[1, 2]) *
                         self.num_channels))
     self.loss_nll = tf.reduce_mean(self.loss_nll)
     # Floor applied to the kld / mmd term through tf.maximum below.
     self.lam = 0.0
     if self.reg_type is None:
         self.loss_reg = 0
     elif self.reg_type == 'kld':
         self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
         self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
     elif self.reg_type == 'mmd':
         # NOTE(review): the reference sample is fixed at 256 rows rather
         # than following the shape of self.z -- confirm this is intended.
         self.mmd = estimate_mmd(
             tf.random_normal(tf.stack([256, self.z_dim])), self.z)
         self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
     elif self.reg_type == 'tc':
         # NOTE(review): N=10000 is hard-coded; presumably the dataset size
         # expected by estimate_mi_tc_dwkld -- verify against its definition.
         self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
             self.z, self.z_mu, self.z_log_sigma_sq, N=10000)
         # Only the tc term is weighted by beta.
         self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
     else:
         # Without this branch self.loss_reg would be left unassigned and the
         # sum below would fail with an unrelated AttributeError (or silently
         # reuse a stale value).
         raise ValueError(
             "unknown regularization type: {!r}".format(self.reg_type))
     self.loss = self.loss_nll + self.loss_reg
 def __loss(self):
     """Compute the mean negative log-likelihood and bits-per-dimension.

     Stores the mean loss in both ``self.loss`` and ``self.loss_nll`` and
     the mean bits-per-dim in ``self.bits_per_dim``.
     """
     print("******   Compute Loss   ******")
     # 'binary' networks use the Bernoulli loss, all others the logistic mixture.
     loss_fn = bernoulli_loss if self.network_type == 'binary' else mix_logistic_loss
     self.loss = loss_fn(self.x,
                         self.pixel_params,
                         masks=self.masks,
                         output_mean=False)
     # dim = sum of (1 - mask) over axes 1 and 2, times the channel count.
     dim = tf.reduce_sum(1 - self.masks, axis=[1, 2]) * self.num_channels
     bpd = bits_per_dim_tf(nll=self.loss, dim=dim)
     self.bits_per_dim = tf.reduce_mean(bpd)
     # Collapse the unreduced loss to its mean and expose it under both names.
     self.loss = tf.reduce_mean(self.loss)
     self.loss_nll = self.loss
 def _loss(self, x, outputs):
     """Return the mean of ``bernoulli_loss(x, outputs, sum_all=False)``."""
     return tf.reduce_mean(bernoulli_loss(x, outputs, sum_all=False))
    def __model(self):
        """Build the graph: placeholders, encoder/decoder and particle gradients.

        Creates the placeholders ``self.x``, ``self.x_bar``,
        ``self.is_training``, ``self.dropout_p``, ``self.masks`` and
        ``self.input_masks``, selects the encoder/decoder functions from
        ``self.img_size`` and ``self.network_type``, and then, with
        ``self.num_particles`` fixed at 16:

        * encodes a masked copy of ``self.x`` and the unmasked ``self.x`` in
          one call (concatenated along axis 0);
        * samples ``z`` from the encoder output of the masked copy;
        * decodes ``z`` into ``self.pixel_params`` and evaluates the NLL;
        * stores ``self.log_avg_weight`` (log of the particle-averaged
          weight) and ``self.gradients`` (gradients with respect to
          ``self.params``).

        Raises:
            Exception: if the ``img_size`` / ``network_type`` combination is
                not one of (32, 'large') or (28, 'binary').
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
                encoder_q = conv_encoder_32_q
            else:
                raise Exception("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                encoder_q = conv_encoder_28_binary_q
            else:
                raise Exception("unknown network type")
        else:
            # Without this branch an unsupported size fell through and failed
            # below with an unbound `encoder` instead of a clear message.
            raise Exception("unknown image size")
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        # encoder_q is only registered with arg_scope; it is not called here.
        with arg_scope([encoder, decoder, encoder_q], **kwargs):
            inputs = self.x

            self.num_particles = 16

            # First half of the encoder batch: x multiplied by the mask, with
            # the mask appended as an extra channel.
            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            # Second half: the unmasked x with the same mask channel.
            inputs_pos = tf.concat(
                [self.x,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            inputs = tf.concat([inputs, inputs_pos], axis=0)

            # One encoder call for both halves; the first batch_size rows are
            # the "_pr" parameters, the remaining rows come from inputs_pos.
            z_mu, z_log_sigma_sq = encoder(inputs, self.z_dim)
            self.z_mu_pr, self.z_mu = (z_mu[:self.batch_size],
                                       z_mu[self.batch_size:])
            self.z_log_sigma_sq_pr, self.z_log_sigma_sq = (
                z_log_sigma_sq[:self.batch_size],
                z_log_sigma_sq[self.batch_size:])

            # The parameters from the unmasked half are discarded here: from
            # now on z_mu / z_log_sigma_sq alias the "_pr" tensors.
            self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

            # Repeat every batch-level tensor num_particles times along axis 0.
            x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
            masks = tf.tile(self.masks, [self.num_particles, 1, 1])
            self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
            self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
            self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                          [self.num_particles, 1])
            self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                             [self.num_particles, 1])
            sigma = tf.exp(self.z_log_sigma_sq / 2.)

            # Variables the gradients below are taken with respect to.
            self.params = get_trainable_variables(["inference"])

            # z = mu + sigma * epsilon with epsilon drawn from a standard normal.
            dist = tf.distributions.Normal(loc=0., scale=1.)
            epsilon = dist.sample(sample_shape=[
                self.batch_size * self.num_particles, self.z_dim
            ],
                                  seed=None)
            z = self.z_mu + tf.multiply(epsilon, sigma)

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
            if self.network_type == 'binary':
                nll = bernoulli_loss(x,
                                     self.pixel_params,
                                     masks=masks,
                                     output_mean=False)
            else:
                nll = mix_logistic_loss(x,
                                        self.pixel_params,
                                        masks=masks,
                                        output_mean=False)

            # NOTE(review): log_prob_pos / log_prob_pr are built but do not
            # enter log_weights (see the commented-out formula below); they
            # are kept so the constructed graph stays unchanged.
            log_prob_pos = dist.log_prob(epsilon)
            epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                self.z_log_sigma_sq_pr / 2.)
            log_prob_pr = dist.log_prob(epsilon_pr)
            # convert back: regroup the tiled rows into a leading particle
            # axis by stacking consecutive batch_size-row slices.
            log_prob_pr = tf.stack([
                log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                   axis=0)
            log_prob_pos = tf.stack([
                log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                    axis=0)
            log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
            log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
            nll = tf.stack([
                nll[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                           axis=0)
            log_likelihood = -nll

            # log_weights = log_prob_pr + log_likelihood - log_prob_pos
            log_weights = log_likelihood
            # log of the mean weight over particles: logsumexp - log(K).
            log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
            log_avg_weight = log_sum_weight - tf.log(
                tf.to_float(self.num_particles))
            self.log_avg_weight = log_avg_weight

            # Softmax over the particle axis, treated as a constant by the
            # gradient computation (stop_gradient), then squared.
            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=0))
            sq_normalized_weights = tf.square(normalized_weights)

            self.gradients = tf.gradients(
                -tf.reduce_sum(sq_normalized_weights * log_weights, axis=0),
                self.params,
                colocate_gradients_with_ops=True)
Esempio n. 5
0
    def __model(self):
        """Build the graph for the encoder / decoder / PixelCNN variant.

        Creates the placeholders ``self.x``, ``self.x_bar``,
        ``self.is_training``, ``self.dropout_p``, ``self.masks`` and
        ``self.input_masks``, selects the network functions from
        ``self.img_size`` and ``self.network_type``, and then, with
        ``self.num_particles`` fixed at 16:

        * encodes a masked, noise-filled copy of ``self.x`` and the unmasked
          ``self.x`` in one call (concatenated along axis 0);
        * samples ``z`` from the encoder output of the masked copy;
        * combines decoder features, reverse-PixelCNN outputs and the input
          mask into conditioning features for the forward PixelCNN, whose
          output is stored in ``self.pixel_params``;
        * stores ``self.log_avg_weight`` and ``self.gradients`` (gradients
          with respect to ``self.params``).

        NOTE(review): only (img_size 32, 'large') and (img_size 28, 'binary')
        bind every name passed to ``arg_scope``; other combinations leave at
        least one of them unbound (see the notes in the body).
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large
                encoder_q = conv_encoder_32_q
            else:
                # NOTE(review): encoder_q is not assigned on this path, so the
                # arg_scope list below would fail with an unbound local.
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        elif self.img_size == 28:
            # NOTE(review): nothing is assigned for img_size 28 with a
            # non-'binary' network, nor for any other img_size.
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
                forward_pixelcnn = forward_pixel_cnn_28_binary
                reverse_pixelcnn = reverse_pixel_cnn_28_binary
                encoder_q = conv_encoder_28_binary_q
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        # encoder_q is only registered with arg_scope; it is not called here.
        with arg_scope(
            [forward_pixelcnn, reverse_pixelcnn, encoder, decoder, encoder_q],
                **kwargs):
            # Extra defaults for the two PixelCNN functions; "bn" is set to
            # False here, after the outer scope set it to self.bn.
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                self.num_particles = 16

                # First half of the encoder batch: x multiplied by input_masks,
                # plus uniform noise (tf.random_uniform with bounds -1 and 1)
                # multiplied by (1 - input_masks), with the mask appended as an
                # extra channel.
                inp = self.x * broadcast_masks_tf(
                    self.input_masks, num_channels=self.num_channels)
                inp += tf.random_uniform(
                    int_shape(inp), -1, 1) * (1 - broadcast_masks_tf(
                        self.input_masks, num_channels=self.num_channels))
                inp = tf.concat([
                    inp,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                axis=-1)

                # Second half: the unmasked x with an all-ones mask channel.
                inputs_pos = tf.concat([
                    self.x,
                    broadcast_masks_tf(tf.ones_like(self.input_masks),
                                       num_channels=1)
                ],
                                       axis=-1)
                inp = tf.concat([inp, inputs_pos], axis=0)

                # One encoder call for both halves; the first batch_size rows
                # are the "_pr" parameters, the remaining rows come from
                # inputs_pos.
                z_mu, z_log_sigma_sq = encoder(inp, self.z_dim)
                self.z_mu_pr, self.z_mu = z_mu[:self.batch_size], z_mu[
                    self.batch_size:]
                self.z_log_sigma_sq_pr, self.z_log_sigma_sq = z_log_sigma_sq[:self.batch_size], z_log_sigma_sq[
                    self.batch_size:]

                # Repeat every batch-level tensor num_particles times along
                # axis 0.
                x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
                x_bar = tf.tile(self.x_bar, [self.num_particles, 1, 1, 1])
                input_masks = tf.tile(self.input_masks,
                                      [self.num_particles, 1, 1])
                masks = tf.tile(self.masks, [self.num_particles, 1, 1])

                self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
                self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                                 [self.num_particles, 1])
                self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
                self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                              [self.num_particles, 1])

                # The tiled parameters from the unmasked half are discarded
                # here: from now on z_mu / z_log_sigma_sq alias the "_pr"
                # tensors.
                self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr

                sigma = tf.exp(self.z_log_sigma_sq / 2.)

                # Variables the gradients below are taken with respect to.
                self.params = get_trainable_variables(["inference"])

                # z = mu + sigma * epsilon with epsilon drawn from a standard
                # normal.
                dist = tf.distributions.Normal(loc=0., scale=1.)
                epsilon = dist.sample(sample_shape=[
                    self.batch_size * self.num_particles, self.z_dim
                ],
                                      seed=None)
                z = self.z_mu + tf.multiply(epsilon, sigma)

                # Conditioning features for the forward PixelCNN: the tiled
                # input mask, the reverse-PixelCNN outputs and the decoder
                # features, concatenated on the last axis.
                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(x, masks, context=None, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                cond_features = tf.concat([
                    broadcast_masks_tf(input_masks, num_channels=1),
                    cond_features
                ],
                                          axis=-1)

                # The forward PixelCNN is fed the tiled x_bar placeholder.
                self.pixel_params = forward_pixelcnn(x_bar,
                                                     cond_features,
                                                     bn=False)

                if self.network_type == 'binary':
                    nll = bernoulli_loss(x,
                                         self.pixel_params,
                                         masks=masks,
                                         output_mean=False)
                else:
                    nll = mix_logistic_loss(x,
                                            self.pixel_params,
                                            masks=masks,
                                            output_mean=False)

                # NOTE(review): log_prob_pos / log_prob_pr are built but do
                # not enter log_weights (see the commented-out formula below).
                log_prob_pos = dist.log_prob(epsilon)
                epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                    self.z_log_sigma_sq_pr / 2.)
                log_prob_pr = dist.log_prob(epsilon_pr)
                # convert back: regroup the tiled rows into a leading particle
                # axis by stacking consecutive batch_size-row slices.
                log_prob_pr = tf.stack([
                    log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                                       axis=0)
                log_prob_pos = tf.stack([
                    log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                                        axis=0)
                log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
                log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
                nll = tf.stack([
                    nll[self.batch_size * i:self.batch_size * (i + 1)]
                    for i in range(self.num_particles)
                ],
                               axis=0)
                log_likelihood = -nll

                # log_weights = log_prob_pr + log_likelihood - log_prob_pos
                log_weights = log_likelihood
                # log of the mean weight over particles: logsumexp - log(K).
                log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
                log_avg_weight = log_sum_weight - tf.log(
                    tf.to_float(self.num_particles))
                self.log_avg_weight = log_avg_weight

                # Softmax over the particle axis, treated as a constant by the
                # gradient computation (stop_gradient), then squared.
                normalized_weights = tf.stop_gradient(
                    tf.nn.softmax(log_weights, axis=0))
                sq_normalized_weights = tf.square(normalized_weights)
                self.gradients = tf.gradients(-tf.reduce_sum(
                    sq_normalized_weights * log_weights, axis=0),
                                              self.params,
                                              colocate_gradients_with_ops=True)