def construct(self, inputs, is_training, dropout_p, nr_resnet=1, nr_filters=50,
              nonlinearity=tf.nn.relu, bn=False, kernel_initializer=None,
              kernel_regularizer=None):
    """Attach the model to caller-supplied tensors and build its outputs.

    Records the configuration on the instance, then sets ``self.outputs``
    (from ``self._model``), ``self.loss`` (from ``self._loss``) and
    ``self.x_hat`` (from ``bernoulli_sampler``).

    Args:
        inputs: Input handed to ``_model`` and ``_loss``.
        is_training: Forwarded to ``_model`` (presumably a training-mode
            flag -- confirm against ``_model``).
        dropout_p: Forwarded to ``_model``.
        nr_resnet: Forwarded to ``_model``; not stored on the instance.
        nr_filters: Forwarded to ``_model``.
        nonlinearity: Forwarded to ``_model``.
        bn: Forwarded to ``_model``.
        kernel_initializer: Forwarded to ``_model``.
        kernel_regularizer: Forwarded to ``_model``.
    """
    # Tensors / switches owned by the caller.
    self.inputs = inputs
    self.is_training = is_training
    self.dropout_p = dropout_p

    # Architecture settings kept for later inspection.
    self.nr_filters = nr_filters
    self.nonlinearity = nonlinearity
    self.bn = bn
    self.kernel_initializer = kernel_initializer
    self.kernel_regularizer = kernel_regularizer

    # Positional order below is the one ``_model`` is called with.
    model_args = (
        inputs,
        nr_resnet,
        nr_filters,
        nonlinearity,
        dropout_p,
        bn,
        kernel_initializer,
        kernel_regularizer,
        is_training,
    )
    self.outputs = self._model(*model_args)
    self.loss = self._loss(self.inputs, self.outputs)
    self.x_hat = bernoulli_sampler(self.outputs)
def __model(self):
    """Build the forward/reverse PixelCNN graph and its sampling op.

    Creates the placeholders ``x``, ``x_bar``, ``is_training``,
    ``dropout_p`` and ``masks``; runs the reverse network on ``x_bar`` and
    ``masks``; feeds its output to the forward network to obtain
    ``self.pixel_params``; and sets ``self.x_hat`` from the sampler that
    matches ``network_type``.

    Raises:
        Exception: If ``img_size`` / ``network_type`` is not one of the
            supported combinations (32 with 'large', 28 with 'binary').
    """
    print("****** Building Graph ******")

    # choose network size
    # Done before any graph op is created so that an unsupported
    # configuration fails with a clear message instead of an
    # UnboundLocalError once the networks are first used.
    if self.img_size == 32:
        if self.network_type == 'large':
            forward_pixel_cnn = forward_pixel_cnn_32
            reverse_pixel_cnn = reverse_pixel_cnn_32
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            forward_pixel_cnn = forward_pixel_cnn_28_binary
            reverse_pixel_cnn = reverse_pixel_cnn_28_binary
        else:
            raise Exception("unknown network type")
    else:
        raise Exception(
            "unsupported img_size: %r (expected 28 or 32)" % (self.img_size,))

    # placeholders
    self.num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    # NOTE: self.x is created but not consumed by the graph built here.
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)

    # Default arguments applied to both networks through arg_scope.
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
        "nr_logistic_mix": self.nr_logistic_mix,
        "nr_resnet": self.nr_resnet,
        "nr_filters": self.nr_filters,
        "dropout_p": self.dropout_p,
    }
    with arg_scope([forward_pixel_cnn, reverse_pixel_cnn], **kwargs):
        inputs = self.x_bar
        # bn=False is passed explicitly on both calls even though the
        # scope supplies self.bn.
        r_outputs = reverse_pixel_cnn(inputs, self.masks, bn=False)
        self.pixel_params = forward_pixel_cnn(inputs, r_outputs, bn=False)

        if self.network_type == 'binary':
            self.x_hat = bernoulli_sampler(
                self.pixel_params, counters=self.counters)
        else:
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def construct(self, device, img_size, batch_size, nr_resnet=1, nr_filters=50,
              nonlinearity=tf.nn.relu, bn=False, kernel_initializer=None,
              kernel_regularizer=None):
    """Create the input placeholders and build outputs, loss and sampler.

    Stores every argument on the instance, creates the placeholders
    ``self.X``, ``self.dropout_p`` and ``self.is_training``, then sets
    ``self.outputs`` (from ``self._model``), ``self.loss`` (from
    ``self._loss``) and ``self.x_hat`` (from ``bernoulli_sampler``).

    Args:
        device: Stored as ``self.device``; not otherwise used here.
        img_size: Height and width of the single-channel input placeholder.
        batch_size: Leading dimension of the input placeholder.
        nr_resnet: Forwarded to ``_model``.
        nr_filters: Forwarded to ``_model``.
        nonlinearity: Forwarded to ``_model``.
        bn: Forwarded to ``_model``.
        kernel_initializer: Forwarded to ``_model``.
        kernel_regularizer: Forwarded to ``_model``.
    """
    # Static configuration.
    self.device = device
    self.img_size = img_size
    self.batch_size = batch_size
    self.nr_resnet = nr_resnet
    self.nr_filters = nr_filters
    self.nonlinearity = nonlinearity
    self.bn = bn
    self.kernel_initializer = kernel_initializer
    self.kernel_regularizer = kernel_regularizer

    # Placeholders, created in the order X -> dropout_p -> is_training.
    input_shape = (batch_size, img_size, img_size, 1)
    self.X = tf.placeholder(tf.float32, shape=input_shape)
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.is_training = tf.placeholder(tf.bool, shape=())

    # Positional order below is the one ``_model`` is called with.
    model_args = (
        self.X,
        self.nr_resnet,
        self.nr_filters,
        self.nonlinearity,
        self.dropout_p,
        self.bn,
        self.kernel_initializer,
        self.kernel_regularizer,
        self.is_training,
    )
    self.outputs = self._model(*model_args)
    self.loss = self._loss(self.X, self.outputs)
    self.x_hat = bernoulli_sampler(self.outputs)
def __model(self):
    """Build the encoder/decoder + PixelCNN graph and its sampling op.

    Creates the placeholders ``x``, ``x_bar``, ``is_training``,
    ``dropout_p``, ``masks`` and ``input_masks``. The input ``x`` is
    masked by ``input_masks``, filled with uniform noise where the mask is
    zero, and encoded into ``z_mu`` / ``z_log_sigma_sq``. The latent is
    decoded into features, which are concatenated with the reverse
    network's output and the mask to condition the forward network that
    produces ``self.pixel_params``. ``self.x_hat`` is set from the sampler
    matching ``network_type``.

    ``z_ph`` / ``use_z_ph`` are placeholders-with-default that let the
    caller substitute a latent of their own for the sampled one.

    Raises:
        Exception: If ``img_size`` is neither 32 nor 28, or if ``img_size``
            is 28 and ``network_type`` is not 'binary'.
    """
    print("****** Building Graph ******")

    # choose network size
    # Done before any graph op is created so that an unsupported
    # configuration fails with a clear message instead of an
    # UnboundLocalError once the networks are first used.
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large
        else:
            encoder = conv_encoder_32
            decoder = conv_decoder_32
        # Both 32-px variants share the small PixelCNN pair.
        forward_pixelcnn = forward_pixel_cnn_32_small
        reverse_pixelcnn = reverse_pixel_cnn_32_small
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
            forward_pixelcnn = forward_pixel_cnn_28_binary
            reverse_pixelcnn = reverse_pixel_cnn_28_binary
        else:
            raise Exception("unknown network type")
    else:
        raise Exception(
            "unsupported img_size: %r (expected 28 or 32)" % (self.img_size,))

    # placeholders
    self.num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)

    # Defaults shared by all four networks.
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    # Extra defaults for the PixelCNN pair only; "bn" is overridden to
    # False for them in the inner scope.
    kwargs_pixelcnn = {
        "nr_resnet": self.nr_resnet,
        "nr_filters": self.nr_filters,
        "nr_logistic_mix": self.nr_logistic_mix,
        "dropout_p": self.dropout_p,
        "bn": False,
    }
    with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                   **kwargs):
        with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                       **kwargs_pixelcnn):
            # self.input_masks is always the placeholder created above, so
            # the masking is applied unconditionally.
            inputs = self.x
            inputs = inputs * broadcast_masks_tf(
                self.input_masks, num_channels=self.num_channels)
            # Fill the masked-out region with uniform noise in [-1, 1).
            inputs += tf.random_uniform(
                int_shape(inputs), -1, 1) * (1 - broadcast_masks_tf(
                    self.input_masks, num_channels=self.num_channels))
            # Append the mask itself as an extra input channel.
            inputs = tf.concat([
                inputs,
                broadcast_masks_tf(self.input_masks, num_channels=1)
            ], axis=-1)

            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            # exp(log(sigma^2) / 2) == sigma, taking the encoder's second
            # output to be a log-variance as its name indicates.
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)

            self.z_ph = tf.placeholder_with_default(
                tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())
            # Turn the boolean switch into 0.0 / 1.0 so it can blend the
            # sampled latent with the externally fed one.
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

            decoded_features = decoder(z, output_features=True)
            # NOTE: the reverse network gets no context; the decoded
            # features are concatenated to its output afterwards instead
            # of being passed as ``context=decoded_features``.
            r_outputs = reverse_pixelcnn(
                self.x, self.masks, context=None, bn=False)
            cond_features = tf.concat(
                [r_outputs, decoded_features], axis=-1)
            cond_features = tf.concat([
                broadcast_masks_tf(self.input_masks, num_channels=1),
                cond_features
            ], axis=-1)
            self.pixel_params = forward_pixelcnn(
                self.x_bar, cond_features, bn=False)

            if self.network_type == 'binary':
                self.x_hat = bernoulli_sampler(
                    self.pixel_params, counters=self.counters)
            else:
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
def __model(self):
    """Build the conditioning-network + prior PixelCNN graph and sampler.

    Creates the placeholders ``x``, ``x_bar``, ``is_training``,
    ``dropout_p`` and ``masks``; computes ``self.cond_features`` from ``x``
    and ``masks`` with the conditioning network; feeds ``x_bar`` and those
    features to the prior network to obtain ``self.pixel_params``; and
    sets ``self.x_hat`` from the sampler matching ``network_type``.

    Raises:
        Exception: If ``img_size`` / ``network_type`` is not one of the
            supported combinations (32 with 'large', 28 with 'binary').
    """
    print("****** Building Graph ******")

    # choose network size
    # Done before any graph op is created so that an unsupported image
    # size fails with a clear message instead of an UnboundLocalError
    # once the networks are first used.
    if self.img_size == 32:
        if self.network_type == 'large':
            conditioning_network = conditioning_network_32
            prior_network = forward_pixel_cnn_32
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            conditioning_network = conditioning_network_28
            prior_network = forward_pixel_cnn_28_binary
        else:
            raise Exception("unknown network type")
    else:
        raise Exception(
            "unsupported img_size: %r (expected 28 or 32)" % (self.img_size,))

    # placeholders
    self.num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)

    # Defaults shared by both networks.
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    # Extra defaults for the prior network only; "bn" is overridden to
    # False for it in the inner scope.
    kwargs_pixelcnn = {
        "nr_resnet": self.nr_resnet,
        "nr_filters": self.nr_filters,
        "nr_logistic_mix": self.nr_logistic_mix,
        "dropout_p": self.dropout_p,
        "bn": False,
    }
    with arg_scope([conditioning_network, prior_network], **kwargs):
        with arg_scope([prior_network], **kwargs_pixelcnn):
            # nr_filters is passed explicitly here because the inner
            # scope's defaults cover only the prior network.
            self.cond_features = conditioning_network(
                self.x, self.masks, nr_filters=self.nr_filters)
            self.pixel_params = prior_network(
                self.x_bar, self.cond_features, bn=False)

            if self.network_type == 'binary':
                self.x_hat = bernoulli_sampler(
                    self.pixel_params, counters=self.counters)
            else:
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
def __model(self):
    """Build the masked encoder/decoder graph and its sampling op.

    Creates the placeholders ``x``, ``x_bar``, ``is_training``,
    ``dropout_p``, ``masks`` and ``input_masks``. The input ``x`` is
    multiplied by ``masks``, the mask is appended as an extra channel, and
    the result is encoded into ``z_mu`` / ``z_log_sigma_sq``. The latent
    is decoded into ``self.pixel_params``, and ``self.x_hat`` is set from
    the sampler matching ``network_type``.

    ``z_ph`` / ``use_z_ph`` are placeholders-with-default that let the
    caller substitute a latent of their own for the sampled one.

    Raises:
        Exception: If ``img_size`` / ``network_type`` is not one of the
            supported combinations (32 with 'large', 28 with 'binary').
    """
    print("****** Building Graph ******")

    # choose network size
    # Done before any graph op is created so that an unsupported image
    # size fails with a clear message instead of an UnboundLocalError
    # once the networks are first used.
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large_mixture_logistic
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
        else:
            raise Exception("unknown network type")
    else:
        raise Exception(
            "unsupported img_size: %r (expected 28 or 32)" % (self.img_size,))

    # placeholders
    self.num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    # NOTE: input_masks is created but not consumed by the graph built
    # here; the encoder input is masked with self.masks.
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)

    # Defaults shared by encoder and decoder.
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([encoder, decoder], **kwargs):
        inputs = self.x
        inputs = inputs * broadcast_masks_tf(
            self.masks, num_channels=self.num_channels)
        # Append the mask itself as an extra input channel.
        inputs = tf.concat(
            [inputs, broadcast_masks_tf(self.masks, num_channels=1)],
            axis=-1)

        self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
        # exp(log(sigma^2) / 2) == sigma, taking the encoder's second
        # output to be a log-variance as its name indicates.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        self.z = gaussian_sampler(self.z_mu, sigma)

        self.z_ph = tf.placeholder_with_default(
            tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
        self.use_z_ph = tf.placeholder_with_default(False, shape=())
        # Turn the boolean switch into 0.0 / 1.0 so it can blend the
        # sampled latent with the externally fed one.
        use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
        z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

        if self.network_type == 'binary':
            self.pixel_params = decoder(z)
            self.x_hat = bernoulli_sampler(
                self.pixel_params, counters=self.counters)
        else:
            self.pixel_params = decoder(
                z, nr_logistic_mix=self.nr_logistic_mix)
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)