Пример #1
0
 def __model(self,
             x,
             x_bar,
             is_training,
             dropout_p,
             masks,
             input_masks,
             network_size="medium"):
     """Build the encoder / conditional-PixelCNN graph from given tensors.

     The input tensors are stored on ``self``. ``x`` is encoded into a
     Gaussian (``self.z_mu``, ``self.z_log_sigma_sq``); the latent ``self.z``
     is sampled from it in 'train' mode, or is a fresh placeholder in 'test'
     mode. ``self.z`` is decoded into features that condition a PixelCNN over
     ``x_bar``, and ``self.x_hat`` is drawn from the PixelCNN output with
     ``mix_logistic_sampler``.

     Args:
         x: input tensor; ``int_shape(x)[1]`` (64 or 32) selects the networks.
         x_bar: tensor fed to the conditional PixelCNN.
         is_training: forwarded to the sub-networks via ``arg_scope``.
         dropout_p: dropout probability passed to the PixelCNN.
         masks: stored as ``self.masks``; not used by this graph.
         input_masks: stored as ``self.input_masks``; not used by this graph.
         network_size: 'medium', 'large' or 'large1'. Only consulted for
             32-pixel inputs; 64-pixel inputs always use the medium networks.

     Raises:
         ValueError: on an unknown ``network_size``, an input size other than
             64 or 32, or a ``self.use_mode`` other than 'train' / 'test'.
     """
     print("******   Building Graph   ******")
     self.x = x
     self.x_bar = x_bar
     self.is_training = is_training
     self.dropout_p = dropout_p
     self.masks = masks
     self.input_masks = input_masks

     img_size = int_shape(x)[1]
     if img_size == 64:
         # network_size is not consulted for 64-pixel inputs.
         conv_encoder = conv_encoder_64_medium
         conv_decoder = conv_decoder_64_medium
     elif img_size == 32:
         if network_size == 'medium':
             conv_encoder = conv_encoder_32_medium
             conv_decoder = conv_decoder_32_medium
         elif network_size == 'large':
             conv_encoder = conv_encoder_32_large
             conv_decoder = conv_decoder_32_large
         elif network_size == 'large1':
             conv_encoder = conv_encoder_32_large1
             conv_decoder = conv_decoder_32_large1
         else:
             raise ValueError("unknown network type")
     else:
         # Without this branch the code below would fail with an obscure
         # UnboundLocalError on conv_encoder.
         raise ValueError("unsupported image size: {0}".format(img_size))

     with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                    nonlinearity=self.nonlinearity,
                    bn=self.bn,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    is_training=self.is_training,
                    counters=self.counters):
         inputs = self.x
         self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
         # z_log_sigma_sq is a log-variance, so sigma = exp(log_var / 2).
         sigma = tf.exp(self.z_log_sigma_sq / 2.)
         if self.use_mode == 'train':
             self.z = gaussian_sampler(self.z_mu, sigma)
         elif self.use_mode == 'test':
             # In test mode the latent code is supplied by the caller.
             self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
         else:
             # Otherwise self.z would be left unset and the decoder call
             # below would fail with an AttributeError.
             raise ValueError("unknown use mode: {0}".format(self.use_mode))
         print("use mode:{0}".format(self.use_mode))
         self.decoded_features = conv_decoder(self.z, output_features=True)
         sh = self.decoded_features
         # bn is overridden to False for the PixelCNN regardless of self.bn.
         self.mix_logistic_params = cond_pixel_cnn(
             self.x_bar,
             sh=sh,
             bn=False,
             dropout_p=self.dropout_p,
             nr_resnet=self.nr_resnet,
             nr_filters=self.nr_filters,
             nr_logistic_mix=self.nr_logistic_mix)
         self.x_hat = mix_logistic_sampler(
             self.mix_logistic_params,
             nr_logistic_mix=self.nr_logistic_mix,
             sample_range=self.sample_range,
             counters=self.counters)
 def __model(self):
     """Build a two-pass (reverse + forward) PixelCNN graph.

     Creates the ``x``, ``x_bar``, ``is_training``, ``dropout_p`` and
     ``masks`` placeholders. ``reverse_pixel_cnn`` is run over ``x_bar`` and
     ``masks``, and its output conditions ``forward_pixel_cnn`` over
     ``x_bar``. ``self.x_hat`` is then sampled from ``self.pixel_params``:
     - 'binary' networks use ``bernoulli_sampler``;
     - other networks use ``mix_logistic_sampler``.

     ``self.x`` is created as a placeholder but is not consumed by this
     graph.

     Supported configurations are ``img_size`` 32 with ``network_type``
     'large', and ``img_size`` 28 with ``network_type`` 'binary'.

     Raises:
         ValueError: for any other ``img_size`` / ``network_type`` pair.
     """
     print("******   Building Graph   ******")
     # placeholders
     if self.network_type == 'binary':
         self.num_channels = 1
     else:
         self.num_channels = 3
     self.x = tf.placeholder(tf.float32,
                             shape=(self.batch_size, self.img_size,
                                    self.img_size, self.num_channels))
     self.x_bar = tf.placeholder(tf.float32,
                                 shape=(self.batch_size, self.img_size,
                                        self.img_size, self.num_channels))
     self.is_training = tf.placeholder(tf.bool, shape=())
     self.dropout_p = tf.placeholder(tf.float32, shape=())
     self.masks = tf.placeholder(tf.float32,
                                 shape=(self.batch_size, self.img_size,
                                        self.img_size))

     # choose network size
     if self.img_size == 32:
         if self.network_type == 'large':
             forward_pixel_cnn = forward_pixel_cnn_32
             reverse_pixel_cnn = reverse_pixel_cnn_32
         else:
             raise ValueError("unknown network type")
     elif self.img_size == 28:
         if self.network_type == 'binary':
             forward_pixel_cnn = forward_pixel_cnn_28_binary
             reverse_pixel_cnn = reverse_pixel_cnn_28_binary
         else:
             # Previously fell through and failed with UnboundLocalError.
             raise ValueError("unknown network type")
     else:
         raise ValueError(
             "unsupported image size: {0}".format(self.img_size))

     # Shared layer defaults plus the PixelCNN-specific hyper-parameters.
     kwargs = {
         "nonlinearity": self.nonlinearity,
         "bn": self.bn,
         "kernel_initializer": self.kernel_initializer,
         "kernel_regularizer": self.kernel_regularizer,
         "is_training": self.is_training,
         "counters": self.counters,
         "nr_logistic_mix": self.nr_logistic_mix,
         "nr_resnet": self.nr_resnet,
         "nr_filters": self.nr_filters,
         "dropout_p": self.dropout_p,
     }
     with arg_scope([forward_pixel_cnn, reverse_pixel_cnn], **kwargs):
         inputs = self.x_bar
         # bn is explicitly disabled for both PixelCNN passes.
         r_outputs = reverse_pixel_cnn(inputs, self.masks, bn=False)
         self.pixel_params = forward_pixel_cnn(inputs, r_outputs, bn=False)
         if self.network_type == 'binary':
             self.x_hat = bernoulli_sampler(self.pixel_params,
                                            counters=self.counters)
         else:
             self.x_hat = mix_logistic_sampler(
                 self.pixel_params,
                 nr_logistic_mix=self.nr_logistic_mix,
                 sample_range=self.sample_range,
                 counters=self.counters)
    def __model(self, network_type="large"):
        """Build the masked-encoder / PixelCNN-decoder graph.

        Creates the ``x``, ``x_bar``, ``is_training``, ``dropout_p``,
        ``masks`` and ``input_masks`` placeholders, plus a ``use_prior``
        placeholder that defaults to False.

        ``x`` is masked with ``input_masks``. Masked-out pixels are replaced
        by uniform noise in [-1, 1), and the mask is appended as an extra
        channel before encoding. The latent fed to the decoder is chosen by
        two boolean placeholders:
        - ``self.z_ph`` when ``use_z_ph`` is set;
        - otherwise a standard-normal prior sample when ``use_prior`` is set;
        - otherwise the posterior sample ``self.z``.

        Decoder features are concatenated with the reverse-PixelCNN output
        over ``x_bar`` / ``masks`` to condition the forward PixelCNN.
        ``self.x_hat`` is always sampled with ``mix_logistic_sampler``.

        Args:
            network_type: ignored; ``self.network_type`` is what is consulted.
                Kept for interface compatibility.

        Raises:
            ValueError: if ``self.img_size`` is not 32 (the only size this
                builder has networks for).
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            num_channels = 1
        else:
            num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        self.use_prior = tf.placeholder_with_default(False, shape=())

        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large
                decoder = conv_decoder_32_large
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        else:
            # Without this the code below fails with UnboundLocalError.
            raise ValueError(
                "unsupported image size: {0}".format(self.img_size))

        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                       **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                # self.input_masks is always a placeholder at this point, so
                # the masking is applied unconditionally. The mask is
                # broadcast to the input's own channel count (previously
                # hard-coded to 3, which mismatched 1-channel binary input).
                channel_masks = broadcast_masks_tf(self.input_masks,
                                                   num_channels=num_channels)
                inputs = self.x * channel_masks
                # Fill masked-out pixels with uniform noise in [-1, 1).
                inputs += tf.random_uniform(int_shape(inputs), -1,
                                            1) * (1 - channel_masks)
                # Append the mask itself as an extra input channel.
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                   axis=-1)

                self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
                sigma = tf.exp(self.z_log_sigma_sq / 2.)
                self.z = gaussian_sampler(self.z_mu, sigma)
                # Sample from the standard-normal prior, same shape as z.
                self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                             tf.ones_like(sigma))

                self.z_ph = tf.placeholder_with_default(
                    tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
                self.use_z_ph = tf.placeholder_with_default(False, shape=())

                # Boolean switches as 0/1 floats so the latent can be chosen
                # by arithmetic blending instead of tf.cond.
                use_prior = tf.cast(tf.cast(self.use_prior, tf.int32),
                                    tf.float32)
                use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32),
                                   tf.float32)
                z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                    1 - use_z_ph) + use_z_ph * self.z_ph

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                            cond_features,
                                                            bn=False)
                self.x_hat = mix_logistic_sampler(
                    self.mix_logistic_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
    def __model(self):
        """Build the conditioning-network + PixelCNN-prior graph.

        Creates the ``x``, ``x_bar``, ``is_training``, ``dropout_p`` and
        ``masks`` placeholders. ``conditioning_network`` turns ``x`` and
        ``masks`` into ``self.cond_features``, which condition
        ``prior_network`` over ``x_bar``. ``self.x_hat`` is then sampled from
        ``self.pixel_params``:
        - 'binary' networks use ``bernoulli_sampler``;
        - other networks use ``mix_logistic_sampler``.

        Supported configurations are ``img_size`` 32 with ``network_type``
        'large', and ``img_size`` 28 with ``network_type`` 'binary'.

        Raises:
            ValueError: for any other ``img_size`` / ``network_type`` pair.
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))

        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                conditioning_network = conditioning_network_32
                prior_network = forward_pixel_cnn_32
            else:
                raise ValueError("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                conditioning_network = conditioning_network_28
                prior_network = forward_pixel_cnn_28_binary
            else:
                raise ValueError("unknown network type")
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "unsupported image size: {0}".format(self.img_size))

        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([conditioning_network, prior_network], **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            # PixelCNN-specific settings apply to the prior network only.
            with arg_scope([prior_network], **kwargs_pixelcnn):
                self.cond_features = conditioning_network(
                    self.x, self.masks, nr_filters=self.nr_filters)
                self.pixel_params = prior_network(self.x_bar,
                                                  self.cond_features,
                                                  bn=False)
                if self.network_type == 'binary':
                    self.x_hat = bernoulli_sampler(self.pixel_params,
                                                   counters=self.counters)
                else:
                    self.x_hat = mix_logistic_sampler(
                        self.pixel_params,
                        nr_logistic_mix=self.nr_logistic_mix,
                        sample_range=self.sample_range,
                        counters=self.counters)
    def __model(self):
        """Build the masked-encoder / decoder graph.

        Creates the ``x``, ``x_bar``, ``is_training``, ``dropout_p``,
        ``masks`` and ``input_masks`` placeholders.

        ``x`` is multiplied by ``masks`` broadcast to its channel count, and
        the mask is appended as an extra channel before encoding. The decoder
        input is the posterior sample ``self.z``, or ``self.z_ph`` when the
        ``use_z_ph`` placeholder is set. ``self.x_hat`` is sampled from
        ``self.pixel_params``:
        - 'binary' networks use ``bernoulli_sampler``;
        - other networks use ``mix_logistic_sampler``.

        NOTE(review): ``self.input_masks`` is created but not used here; the
        encoder input is masked with ``self.masks``. Confirm this is intended.

        Supported configurations are ``img_size`` 32 with ``network_type``
        'large', and ``img_size`` 28 with ``network_type`` 'binary'.

        Raises:
            ValueError: for any other ``img_size`` / ``network_type`` pair.
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))

        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
            else:
                raise ValueError("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
            else:
                raise ValueError("unknown network type")
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "unsupported image size: {0}".format(self.img_size))

        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([encoder, decoder], **kwargs):
            inputs = self.x
            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            # Append the mask itself as an extra input channel.
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)

            self.z_ph = tf.placeholder_with_default(tf.zeros_like(self.z_mu),
                                                    shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            # Boolean switch as a 0/1 float so z can be chosen by blending.
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
                self.x_hat = bernoulli_sampler(self.pixel_params,
                                               counters=self.counters)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)