def __model(self, x, x_bar, is_training, dropout_p, masks, input_masks, network_size="medium"):
    """Build the graph: encoder -> latent ``z`` -> decoder -> ``cond_pixel_cnn``.

    The supplied tensors are stored on ``self``. The encoder/decoder pair is
    selected from ``int_shape(x)[1]`` and ``network_size``.

    Sets ``self.z_mu``, ``self.z_log_sigma_sq``, ``self.z``,
    ``self.decoded_features``, ``self.mix_logistic_params`` and
    ``self.x_hat``.

    Args:
        x: Encoder input; ``int_shape(x)[1]`` must be 64 or 32.
        x_bar: Input passed to ``cond_pixel_cnn``.
        is_training: Forwarded to the networks through ``arg_scope``.
        dropout_p: Forwarded to ``cond_pixel_cnn``.
        masks: Stored on ``self.masks``; not otherwise used here.
        input_masks: Stored on ``self.input_masks``; not otherwise used here.
        network_size: ``'medium'``, ``'large'`` or ``'large1'``. Only
            consulted when the input size is 32.

    Raises:
        Exception: If no encoder/decoder pair matches the input size and
            ``network_size``, or if ``self.use_mode`` is neither ``'train'``
            nor ``'test'``.
    """
    print("****** Building Graph ******")
    self.x = x
    self.x_bar = x_bar
    self.is_training = is_training
    self.dropout_p = dropout_p
    self.masks = masks
    self.input_masks = input_masks

    # Pick the encoder/decoder pair. Any unmatched combination raises here
    # rather than surfacing later as an unbound local variable.
    side = int_shape(x)[1]
    if side == 64:
        conv_encoder = conv_encoder_64_medium
        conv_decoder = conv_decoder_64_medium
    elif side == 32 and network_size == 'medium':
        conv_encoder = conv_encoder_32_medium
        conv_decoder = conv_decoder_32_medium
    elif side == 32 and network_size == 'large':
        conv_encoder = conv_encoder_32_large
        conv_decoder = conv_decoder_32_large
    elif side == 32 and network_size == 'large1':
        conv_encoder = conv_encoder_32_large1
        conv_decoder = conv_decoder_32_large1
    else:
        raise Exception("unknown network type")

    with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                   nonlinearity=self.nonlinearity,
                   bn=self.bn,
                   kernel_initializer=self.kernel_initializer,
                   kernel_regularizer=self.kernel_regularizer,
                   is_training=self.is_training,
                   counters=self.counters):
        inputs = self.x
        self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
        # Going by the attribute name, this turns a log-variance into a
        # standard deviation: exp(log(sigma^2) / 2) == sigma.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)

        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            # In test mode z is a placeholder with the same shape as z_mu.
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        else:
            # Without this branch self.z would never be assigned.
            raise Exception("unknown use mode: {0}".format(self.use_mode))
        print("use mode:{0}".format(self.use_mode))

        self.decoded_features = conv_decoder(self.z, output_features=True)
        sh = self.decoded_features
        # bn=False overrides the scope-wide bn setting for this call only.
        self.mix_logistic_params = cond_pixel_cnn(
            self.x_bar,
            sh=sh,
            bn=False,
            dropout_p=self.dropout_p,
            nr_resnet=self.nr_resnet,
            nr_filters=self.nr_filters,
            nr_logistic_mix=self.nr_logistic_mix)
        self.x_hat = mix_logistic_sampler(
            self.mix_logistic_params,
            nr_logistic_mix=self.nr_logistic_mix,
            sample_range=self.sample_range,
            counters=self.counters)
def __model(self):
    """Build the graph: reverse PixelCNN feeding a forward PixelCNN.

    Creates the input placeholders (``self.x``, ``self.x_bar``,
    ``self.is_training``, ``self.dropout_p``, ``self.masks``), selects the
    network pair from ``self.img_size`` and ``self.network_type``, and sets
    ``self.pixel_params`` and ``self.x_hat``.

    Supported combinations: ``img_size == 32`` with ``network_type ==
    'large'``, and ``img_size == 28`` with ``network_type == 'binary'``.

    Raises:
        Exception: If ``self.network_type`` is not supported for the image
            size, or if ``self.img_size`` is neither 32 nor 28.
    """
    print("****** Building Graph ******")

    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)

    # choose network size; every unsupported combination raises instead of
    # leaving forward_pixel_cnn / reverse_pixel_cnn unbound
    if self.img_size == 32:
        if self.network_type == 'large':
            forward_pixel_cnn = forward_pixel_cnn_32
            reverse_pixel_cnn = reverse_pixel_cnn_32
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            forward_pixel_cnn = forward_pixel_cnn_28_binary
            reverse_pixel_cnn = reverse_pixel_cnn_28_binary
        else:
            raise Exception("unknown network type")
    else:
        raise Exception(
            "unsupported image size: {0}".format(self.img_size))

    # Defaults applied to both networks through arg_scope.
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
        "nr_logistic_mix": self.nr_logistic_mix,
        "nr_resnet": self.nr_resnet,
        "nr_filters": self.nr_filters,
        "dropout_p": self.dropout_p,
    }
    with arg_scope([forward_pixel_cnn, reverse_pixel_cnn], **kwargs):
        inputs = self.x_bar
        # bn=False overrides the scope-wide bn default for these two calls.
        r_outputs = reverse_pixel_cnn(inputs, self.masks, bn=False)
        self.pixel_params = forward_pixel_cnn(inputs, r_outputs, bn=False)
        if self.network_type == 'binary':
            self.x_hat = bernoulli_sampler(
                self.pixel_params, counters=self.counters)
        else:
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def __model(self, network_type="large"):
    """Build the graph: masked encoder/decoder conditioning a PixelCNN pair.

    Creates the input placeholders, encodes the masked input into a latent
    ``z``, decodes it into features, concatenates those with the reverse
    PixelCNN output and feeds the result to the forward PixelCNN.

    Sets ``self.z_mu``, ``self.z_log_sigma_sq``, ``self.z``, ``self.z_pr``,
    ``self.z_ph``, ``self.use_z_ph``, ``self.mix_logistic_params`` and
    ``self.x_hat``.

    Args:
        network_type: Not read by this method; the selection below uses
            ``self.network_type``. Kept so existing callers keep working.

    Raises:
        Exception: If ``self.img_size`` is not 32, the only size for which
            networks are selected here.
    """
    print("****** Building Graph ******")

    # placeholders
    if self.network_type == 'binary':
        num_channels = 1
    else:
        num_channels = 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.use_prior = tf.placeholder_with_default(False, shape=())

    # choose network size; fail early instead of hitting unbound locals
    if self.img_size != 32:
        raise Exception(
            "unsupported image size: {0}".format(self.img_size))
    if self.network_type == 'large':
        encoder = conv_encoder_32_large
        decoder = conv_decoder_32_large
    else:
        encoder = conv_encoder_32
        decoder = conv_decoder_32
    forward_pixelcnn = forward_pixel_cnn_32_small
    reverse_pixelcnn = reverse_pixel_cnn_32_small

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                   **kwargs):
        # Extra defaults (and bn=False) that apply to the PixelCNN pair only.
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                       **kwargs_pixelcnn):
            # self.input_masks was assigned a placeholder above, so the
            # masking is unconditional.
            # NOTE(review): the mask is broadcast with a hard-coded
            # num_channels=3 even though num_channels is 1 for 'binary';
            # confirm this is intended.
            inputs = self.x
            inputs = inputs * broadcast_masks_tf(
                self.input_masks, num_channels=3)
            # Where the mask is 0, add uniform noise drawn from [-1, 1).
            inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                1 - broadcast_masks_tf(self.input_masks, num_channels=3))
            # Append the mask as one extra input channel.
            inputs = tf.concat([
                inputs,
                broadcast_masks_tf(self.input_masks, num_channels=1)
            ], axis=-1)

            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            # Going by the attribute name, this turns a log-variance into
            # a standard deviation: exp(log(sigma^2) / 2) == sigma.
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)
            # Sample with zero mean and unit sigma, same shape as z.
            self.z_pr = gaussian_sampler(
                tf.zeros_like(self.z_mu), tf.ones_like(sigma))
            self.z_ph = tf.placeholder_with_default(
                tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            # Boolean switches become 0./1. floats so z can be chosen
            # arithmetically: z_ph if use_z_ph, else z_pr if use_prior,
            # else the encoder sample self.z.
            use_prior = tf.cast(tf.cast(self.use_prior, tf.int32), tf.float32)
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                1 - use_z_ph) + use_z_ph * self.z_ph

            decoded_features = decoder(z, output_features=True)
            r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
            cond_features = tf.concat(
                [r_outputs, decoded_features], axis=-1)
            self.mix_logistic_params = forward_pixelcnn(
                self.x_bar, cond_features, bn=False)
            self.x_hat = mix_logistic_sampler(
                self.mix_logistic_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def __model(self):
    """Build the graph: conditioning network feeding a prior network.

    Creates the input placeholders (``self.x``, ``self.x_bar``,
    ``self.is_training``, ``self.dropout_p``, ``self.masks``), selects the
    networks from ``self.img_size`` and ``self.network_type``, and sets
    ``self.cond_features``, ``self.pixel_params`` and ``self.x_hat``.

    Supported combinations: ``img_size == 32`` with ``network_type ==
    'large'``, and ``img_size == 28`` with ``network_type == 'binary'``.

    Raises:
        Exception: If ``self.network_type`` is not supported for the image
            size, or if ``self.img_size`` is neither 32 nor 28.
    """
    print("****** Building Graph ******")

    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)

    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            conditioning_network = conditioning_network_32
            prior_network = forward_pixel_cnn_32
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            conditioning_network = conditioning_network_28
            prior_network = forward_pixel_cnn_28_binary
        else:
            raise Exception("unknown network type")
    else:
        # Previously this case fell through and failed later on an unbound
        # local variable.
        raise Exception(
            "unsupported image size: {0}".format(self.img_size))

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([conditioning_network, prior_network], **kwargs):
        # Extra defaults (and bn=False) that apply to the prior network only.
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([prior_network], **kwargs_pixelcnn):
            self.cond_features = conditioning_network(
                self.x, self.masks, nr_filters=self.nr_filters)
            self.pixel_params = prior_network(
                self.x_bar, self.cond_features, bn=False)
            if self.network_type == 'binary':
                self.x_hat = bernoulli_sampler(
                    self.pixel_params, counters=self.counters)
            else:
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
def __model(self):
    """Build the graph: masked encoder -> latent ``z`` -> decoder -> sampler.

    Creates the input placeholders, multiplies ``self.x`` by the broadcast
    ``self.masks``, appends the mask as an extra channel, encodes the result
    into ``z`` and decodes ``z`` into pixel-distribution parameters.

    Sets ``self.z_mu``, ``self.z_log_sigma_sq``, ``self.z``, ``self.z_ph``,
    ``self.use_z_ph``, ``self.pixel_params`` and ``self.x_hat``.

    Supported combinations: ``img_size == 32`` with ``network_type ==
    'large'``, and ``img_size == 28`` with ``network_type == 'binary'``.

    Raises:
        Exception: If ``self.network_type`` is not supported for the image
            size, or if ``self.img_size`` is neither 32 nor 28.
    """
    print("****** Building Graph ******")

    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    # NOTE(review): input_masks is created but not read in this method;
    # the encoder input below is masked with self.masks.
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)

    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large_mixture_logistic
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
        else:
            raise Exception("unknown network type")
    else:
        # Previously this case fell through and failed later on an unbound
        # local variable.
        raise Exception(
            "unsupported image size: {0}".format(self.img_size))

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([encoder, decoder], **kwargs):
        inputs = self.x
        inputs = inputs * broadcast_masks_tf(
            self.masks, num_channels=self.num_channels)
        # Append the mask as one extra input channel.
        inputs = tf.concat(
            [inputs, broadcast_masks_tf(self.masks, num_channels=1)],
            axis=-1)

        self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
        # Going by the attribute name, this turns a log-variance into a
        # standard deviation: exp(log(sigma^2) / 2) == sigma.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        self.z = gaussian_sampler(self.z_mu, sigma)
        self.z_ph = tf.placeholder_with_default(
            tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
        self.use_z_ph = tf.placeholder_with_default(False, shape=())

        # The boolean switch becomes a 0./1. float so z can be chosen
        # arithmetically: z_ph if use_z_ph, else the encoder sample self.z.
        use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
        z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

        if self.network_type == 'binary':
            self.pixel_params = decoder(z)
            self.x_hat = bernoulli_sampler(
                self.pixel_params, counters=self.counters)
        else:
            self.pixel_params = decoder(
                z, nr_logistic_mix=self.nr_logistic_mix)
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)