def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Build the shifted-convolution feature stack over a masked input.

    ``x`` is multiplied by ``masks`` (broadcast over x's channels), the mask
    is appended as an extra channel, a channel of ones is appended, and the
    result is fed through an "up" stream and an "up-left" stream of shifted
    convolutions, each followed by ``nr_resnet`` gated residual blocks.

    Args:
        x: channels-last image tensor; presumably (batch, 28, 28, 1) for
            binary images -- TODO confirm with callers. Its channel dimension
            must be statically known.
        masks: mask tensor expanded by ``broadcast_masks_tf``.
        context: forwarded to every ``gated_resnet`` as ``sh``.
        nr_logistic_mix: only printed; no layer here uses it.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: number of filters of every convolution.
        dropout_p: forwarded to ``gated_resnet``.
        nonlinearity: forwarded to ``gated_resnet``.
        bn: must be falsy; an AssertionError is raised otherwise.
        kernel_initializer, kernel_regularizer, is_training, counters:
            forwarded to the layer helpers through ``arg_scope``.
            NOTE(review): ``counters={}`` is a mutable default shared by all
            calls that omit it; kept as-is because ``get_name`` may rely on it.

    Returns:
        The output of ``nin`` (``nr_filters`` outputs) applied to ``elu`` of
        the last up-left stream tensor.
    """
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    # Size the mask to x's own channel count: a fixed 3 would silently turn a
    # single-channel (binary) input into three identical channels.
    x = x * broadcast_masks_tf(masks, num_channels=int_shape(x)[-1])
    # Append the mask itself as an extra input channel.
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
                    up_shifted_deconv2d, up_left_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                xs = int_shape(x)
                # add channel of ones to distinguish image from padding later on
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # stream for pixels above
                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]
                # stream for up and to the left
                ul_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[1, 3])) +
                    left_shift(
                        up_left_shifted_conv2d(x_pad,
                                               num_filters=nr_filters,
                                               filter_size=[2, 1]))
                ]
                for _ in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    # The up-left stream also receives the up stream's latest output.
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
    return x_out
def conditioning_network_28(x,
                            masks,
                            nr_filters,
                            is_training=True,
                            nonlinearity=None,
                            bn=True,
                            kernel_initializer=None,
                            kernel_regularizer=None,
                            counters={}):
    """Encode a masked image into a feature map with a stack of convolutions.

    ``x`` is multiplied by ``masks`` (broadcast over x's channels). The mask
    and then a channel of ones are appended, and the result goes through five
    ``conv2d(..., nr_filters, 4, 1, "SAME")`` layers followed by a final
    ``conv2d(..., nr_filters, 1, 1, "SAME")`` with no nonlinearity and no
    batch norm.

    Args:
        x: channels-last image tensor; its channel dimension must be
            statically known (shape presumably (batch, 28, 28, C) -- TODO
            confirm).
        masks: mask tensor expanded by ``broadcast_masks_tf``.
        nr_filters: number of filters of every convolution.
        is_training, nonlinearity, bn, kernel_initializer,
        kernel_regularizer, counters: forwarded to ``conv2d``,
            ``residual_block`` and ``dense`` through ``arg_scope``.
            NOTE(review): ``counters={}`` is a shared mutable default; kept
            as-is because ``get_name`` may rely on it.

    Returns:
        The output tensor of the last convolution.
    """
    name = get_name("conditioning_network_28", counters)
    # Size the mask to x's own channel count instead of a hard-coded 3, which
    # would triple a single-channel input through broadcasting.
    x = x * broadcast_masks_tf(masks, num_channels=int_shape(x)[-1])
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    xs = int_shape(x)
    # Extra channel of ones, as in the shifted-conv networks of this file.
    x = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
    with tf.variable_scope(name):
        with arg_scope([conv2d, residual_block, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            outputs = conv2d(x, nr_filters, 4, 1, "SAME")
            for _ in range(4):
                outputs = conv2d(outputs, nr_filters, 4, 1, "SAME")
            # Final projection without nonlinearity / batch norm.
            outputs = conv2d(outputs,
                             nr_filters,
                             1,
                             1,
                             "SAME",
                             nonlinearity=None,
                             bn=False)
    return outputs
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    """Encode masked context images with two shifted-convolution streams.

    The contexts are multiplied by a 3-channel broadcast of ``masks``. Then a
    1-channel mask and a channel of ones are appended. An "up" stream and an
    "up-left" stream are started from shifted convolutions and refined by
    ``nr_resnet`` gated residual blocks each. The printed receptive field
    starts at (2, 3) and grows by (1, 2) per residual block.

    Returns:
        ``nin`` (``nr_filters`` outputs) applied to ``elu`` of the final
        up-left stream tensor.

    NOTE(review): ``counters`` is only scoped onto ``gated_resnet``; the
    shifted convolutions and ``nin`` use their own defaults.
    """
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    masked = contexts * broadcast_masks_tf(masks, num_channels=3)
    masked = tf.concat(
        [masked, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    layer_kwargs = dict(bn=bn,
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=kernel_regularizer,
                        is_training=is_training)
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope(
                    [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    **layer_kwargs):
                shape = int_shape(masked)
                # Constant-ones channel lets later layers tell image from padding.
                padded = tf.concat([masked, tf.ones(shape[:-1] + [1])], 3)
                # Tail of the stream looking at pixels above.
                up = up_shift(
                    up_shifted_conv2d(padded,
                                      num_filters=nr_filters,
                                      filter_size=[2, 3]))
                # Tail of the stream looking above and to the left.
                above_part = up_shift(
                    up_shifted_conv2d(padded,
                                      num_filters=nr_filters,
                                      filter_size=[1, 3]))
                left_part = left_shift(
                    up_left_shifted_conv2d(padded,
                                           num_filters=nr_filters,
                                           filter_size=[2, 1]))
                up_left = above_part + left_part
                rf_rows, rf_cols = 2, 3
                for _ in range(nr_resnet):
                    up = gated_resnet(up, conv=up_shifted_conv2d)
                    up_left = gated_resnet(up_left,
                                           up,
                                           conv=up_left_shifted_conv2d)
                    rf_rows += 1
                    rf_cols += 2
                x_out = nin(tf.nn.elu(up_left), nr_filters)
                print(" * receptive_field", (rf_rows, rf_cols))
    return x_out
def __model(self, network_type="large"):
    """Build the VAE + forward/reverse PixelCNN graph for 32x32 images.

    Creates the input placeholders on ``self``. It encodes ``self.x`` (with
    ``self.input_masks``) into ``self.z_mu`` / ``self.z_log_sigma_sq`` and
    picks the latent ``z`` from the posterior sample, the prior sample or
    ``self.z_ph``. It decodes ``z`` to features, conditions the forward
    pixelcnn on those features plus the reverse pixelcnn's output, and sets
    ``self.mix_logistic_params`` and ``self.x_hat``.

    Args:
        network_type: not read; the network size comes from
            ``self.network_type``.

    Raises:
        Exception: if ``self.img_size`` is not 32.
    """
    print("****** Building Graph ******")
    # placeholders
    if self.network_type == 'binary':
        num_channels = 1
    else:
        num_channels = 3
    self.x = tf.placeholder(tf.float32,
                            shape=(self.batch_size, self.img_size,
                                   self.img_size, num_channels))
    self.x_bar = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, num_channels))
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size))
    self.input_masks = tf.placeholder(tf.float32,
                                      shape=(self.batch_size, self.img_size,
                                             self.img_size))
    self.use_prior = tf.placeholder_with_default(False, shape=())
    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large
            decoder = conv_decoder_32_large
        else:
            encoder = conv_encoder_32
            decoder = conv_decoder_32
        forward_pixelcnn = forward_pixel_cnn_32_small
        reverse_pixelcnn = reverse_pixel_cnn_32_small
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise Exception("unknown image size: {0}".format(self.img_size))
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                   **kwargs):
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                       **kwargs_pixelcnn):
            inputs = self.x
            if self.input_masks is not None:
                # Keep pixels where the mask is set and add uniform [-1, 1)
                # noise scaled by (1 - mask), then append the mask channel.
                inputs = inputs * broadcast_masks_tf(self.input_masks,
                                                     num_channels=3)
                inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                    1 - broadcast_masks_tf(self.input_masks, num_channels=3))
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                   axis=-1)
            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)
            self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                         tf.ones_like(sigma))
            self.z_ph = tf.placeholder_with_default(
                tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())
            use_prior = tf.cast(tf.cast(self.use_prior, tf.int32),
                                tf.float32)
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            # use_z_ph selects self.z_ph; otherwise use_prior picks between
            # the prior sample and the posterior sample.
            z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                1 - use_z_ph) + use_z_ph * self.z_ph
            decoded_features = decoder(z, output_features=True)
            r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
            cond_features = tf.concat([r_outputs, decoded_features], axis=-1)
            self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                        cond_features,
                                                        bn=False)
            self.x_hat = mix_logistic_sampler(
                self.mix_logistic_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def __model(self,
            x,
            x_bar,
            is_training,
            dropout_p,
            masks,
            input_masks,
            network_size="medium"):
    """Build the VAE + conditional PixelCNN graph from caller-supplied tensors.

    Stores the inputs on ``self`` and picks the encoder/decoder from the
    spatial size of ``x``. For 64x64 inputs ``network_size`` is ignored. The
    method encodes the (optionally noised and masked) input and obtains
    ``self.z`` by sampling ('train') or from a new placeholder ('test'). It
    decodes features, optionally appends a context encoding and the input
    mask, and runs ``cond_pixel_cnn`` to set ``self.mix_logistic_params`` and
    ``self.x_hat``.

    Raises:
        Exception: for an unsupported image size, ``network_size`` or
            ``self.use_mode``.
    """
    print("****** Building Graph ******")
    self.x = x
    self.x_bar = x_bar
    self.is_training = is_training
    self.dropout_p = dropout_p
    self.masks = masks
    self.input_masks = input_masks
    img_size = int_shape(x)[1]
    if img_size == 64:
        conv_encoder = conv_encoder_64_medium
        conv_decoder = conv_decoder_64_medium
    elif img_size == 32:
        if network_size == 'medium':
            conv_encoder = conv_encoder_32_medium
            conv_decoder = conv_decoder_32_medium
        elif network_size == 'large':
            conv_encoder = conv_encoder_32_large
            conv_decoder = conv_decoder_32_large
        elif network_size == 'large1':
            conv_encoder = conv_encoder_32_large1
            conv_decoder = conv_decoder_32_large1
        else:
            raise Exception("unknown network type")
    else:
        # Without this, conv_encoder would be unbound (UnboundLocalError).
        raise Exception("unknown image size: {0}".format(img_size))
    with arg_scope(
        [conv_encoder, conv_decoder, context_encoder, cond_pixel_cnn],
            nonlinearity=self.nonlinearity,
            bn=self.bn,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            is_training=self.is_training,
            counters=self.counters):
        inputs = self.x
        if self.input_masks is not None:
            # Keep pixels where the mask is set, add uniform [-1, 1) noise
            # scaled by (1 - mask), then append the mask as a channel.
            inputs = inputs * broadcast_masks_tf(self.input_masks,
                                                 num_channels=3)
            inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                1 - broadcast_masks_tf(self.input_masks, num_channels=3))
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.input_masks, num_channels=1)],
                axis=-1)
        self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        else:
            # Otherwise self.z would be missing or left over from an earlier build.
            raise Exception("unknown use mode: {0}".format(self.use_mode))
        print("use mode:{0}".format(self.use_mode))
        self.decoded_features = conv_decoder(self.z, output_features=True)
        if self.masks is None:
            sh = self.decoded_features
        else:
            self.encoded_context = context_encoder(self.x,
                                                   self.masks,
                                                   bn=False,
                                                   nr_resnet=self.nr_resnet,
                                                   nr_filters=self.nr_filters)
            sh = tf.concat([self.decoded_features, self.encoded_context],
                           axis=-1)
        if self.input_masks is not None:
            sh = tf.concat(
                [broadcast_masks_tf(self.input_masks, num_channels=1), sh],
                axis=-1)
        self.mix_logistic_params = cond_pixel_cnn(
            self.x_bar,
            sh=sh,
            bn=False,
            dropout_p=self.dropout_p,
            nr_resnet=self.nr_resnet,
            nr_filters=self.nr_filters,
            nr_logistic_mix=self.nr_logistic_mix)
        self.x_hat = mix_logistic_sampler(
            self.mix_logistic_params,
            nr_logistic_mix=self.nr_logistic_mix,
            sample_range=self.sample_range,
            counters=self.counters)
def __model(self):
    """Build an importance-weighted graph for training the inference encoder.

    A masked copy of ``self.x`` (first half of the batch) and the unmasked
    ``self.x`` (second half) are encoded in one call, giving the ``*_pr`` and
    posterior statistics. Everything is tiled ``self.num_particles`` (16)
    times along the batch axis, ``z`` is sampled with the reparameterisation
    ``mu + eps * sigma``, decoded, and scored per particle. The method sets:
    ``self.log_avg_weight`` (logsumexp over particles minus
    log(num_particles)), ``self.params`` (variables from
    ``get_trainable_variables(["inference"])``) and ``self.gradients``.

    Raises:
        Exception: for an unsupported ``img_size`` / ``network_type`` pair.
    """
    print("****** Building Graph ******")
    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    self.x = tf.placeholder(tf.float32,
                            shape=(self.batch_size, self.img_size,
                                   self.img_size, self.num_channels))
    self.x_bar = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size))
    self.input_masks = tf.placeholder(tf.float32,
                                      shape=(self.batch_size, self.img_size,
                                             self.img_size))
    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large_mixture_logistic
            encoder_q = conv_encoder_32_q
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
            encoder_q = conv_encoder_28_binary_q
        else:
            raise Exception("unknown network type")
    else:
        # Without this, encoder/decoder/encoder_q would be unbound below.
        raise Exception("unknown image size: {0}".format(self.img_size))
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([encoder, decoder, encoder_q], **kwargs):
        inputs = self.x
        self.num_particles = 16
        inputs = inputs * broadcast_masks_tf(self.masks,
                                             num_channels=self.num_channels)
        inputs = tf.concat(
            [inputs, broadcast_masks_tf(self.masks, num_channels=1)], axis=-1)
        inputs_pos = tf.concat(
            [self.x, broadcast_masks_tf(self.masks, num_channels=1)], axis=-1)
        # One encoder call: first batch_size rows masked, the rest unmasked.
        inputs = tf.concat([inputs, inputs_pos], axis=0)
        z_mu, z_log_sigma_sq = encoder(inputs, self.z_dim)
        self.z_mu_pr, self.z_mu = z_mu[:self.batch_size], z_mu[self.batch_size:]
        self.z_log_sigma_sq_pr, self.z_log_sigma_sq = (
            z_log_sigma_sq[:self.batch_size], z_log_sigma_sq[self.batch_size:])
        # NOTE(review): the posterior statistics are replaced by the masked
        # ("_pr") ones, so z is drawn from the _pr distribution and
        # epsilon_pr below equals epsilon. Looks deliberate -- confirm.
        self.z_mu, self.z_log_sigma_sq = self.z_mu_pr, self.z_log_sigma_sq_pr
        # tf.tile repeats the whole batch, so row i * batch_size + b is
        # particle i of example b.
        x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
        masks = tf.tile(self.masks, [self.num_particles, 1, 1])
        self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
        self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
        self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                      [self.num_particles, 1])
        self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                         [self.num_particles, 1])
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        self.params = get_trainable_variables(["inference"])
        dist = tf.distributions.Normal(loc=0., scale=1.)
        epsilon = dist.sample(
            sample_shape=[self.batch_size * self.num_particles, self.z_dim],
            seed=None)
        z = self.z_mu + tf.multiply(epsilon, sigma)
        if self.network_type == 'binary':
            self.pixel_params = decoder(z)
        else:
            self.pixel_params = decoder(z,
                                        nr_logistic_mix=self.nr_logistic_mix)
        if self.network_type == 'binary':
            nll = bernoulli_loss(x,
                                 self.pixel_params,
                                 masks=masks,
                                 output_mean=False)
        else:
            nll = mix_logistic_loss(x,
                                    self.pixel_params,
                                    masks=masks,
                                    output_mean=False)
        log_prob_pos = dist.log_prob(epsilon)
        epsilon_pr = (z - self.z_mu_pr) / tf.exp(self.z_log_sigma_sq_pr / 2.)
        log_prob_pr = dist.log_prob(epsilon_pr)
        # convert back: regroup the tiled rows as [num_particles, batch_size, ...]
        log_prob_pr = tf.stack([
            log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
            for i in range(self.num_particles)
        ],
                               axis=0)
        log_prob_pos = tf.stack([
            log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
            for i in range(self.num_particles)
        ],
                                axis=0)
        log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
        log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
        nll = tf.stack([
            nll[self.batch_size * i:self.batch_size * (i + 1)]
            for i in range(self.num_particles)
        ],
                       axis=0)
        log_likelihood = -nll
        # NOTE(review): the prior/posterior correction is disabled, so
        # log_prob_pr and log_prob_pos do not feed into any output.
        # log_weights = log_prob_pr + log_likelihood - log_prob_pos
        log_weights = log_likelihood
        log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
        log_avg_weight = log_sum_weight - tf.log(
            tf.to_float(self.num_particles))
        self.log_avg_weight = log_avg_weight
        # Weights are softmax-normalised over particles and held constant
        # (stop_gradient); the objective is -sum_k w_k^2 * log_weights_k.
        normalized_weights = tf.stop_gradient(
            tf.nn.softmax(log_weights, axis=0))
        sq_normalized_weights = tf.square(normalized_weights)
        self.gradients = tf.gradients(
            -tf.reduce_sum(sq_normalized_weights * log_weights, axis=0),
            self.params,
            colocate_gradients_with_ops=True)
def __model(self):
    """Build a VAE graph over masked inputs.

    Creates the input placeholders on ``self``. It encodes ``self.x``
    (masked by ``self.masks``, with the mask appended as a channel) into
    ``self.z_mu`` / ``self.z_log_sigma_sq`` and samples ``self.z``. The
    decoder input is ``self.z`` unless ``self.use_z_ph`` is fed True, in
    which case ``self.z_ph`` is used. The method then sets
    ``self.pixel_params`` and ``self.x_hat``: a Bernoulli sample for binary
    networks and a mixture-of-logistics sample otherwise.

    Raises:
        Exception: for an unsupported ``img_size`` / ``network_type`` pair.
    """
    print("****** Building Graph ******")
    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    self.x = tf.placeholder(tf.float32,
                            shape=(self.batch_size, self.img_size,
                                   self.img_size, self.num_channels))
    self.x_bar = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size))
    self.input_masks = tf.placeholder(tf.float32,
                                      shape=(self.batch_size, self.img_size,
                                             self.img_size))
    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large_mixture_logistic
        else:
            raise Exception("unknown network type")
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
        else:
            raise Exception("unknown network type")
    else:
        # Without this, encoder/decoder would be unbound (UnboundLocalError).
        raise Exception("unknown image size: {0}".format(self.img_size))
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([encoder, decoder], **kwargs):
        inputs = self.x
        inputs = inputs * broadcast_masks_tf(self.masks,
                                             num_channels=self.num_channels)
        inputs = tf.concat(
            [inputs, broadcast_masks_tf(self.masks, num_channels=1)], axis=-1)
        self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        self.z = gaussian_sampler(self.z_mu, sigma)
        self.z_ph = tf.placeholder_with_default(tf.zeros_like(self.z_mu),
                                                shape=int_shape(self.z_mu))
        self.use_z_ph = tf.placeholder_with_default(False, shape=())
        use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
        # Select between the sampled latent and the externally fed one.
        z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph
        if self.network_type == 'binary':
            self.pixel_params = decoder(z)
            self.x_hat = bernoulli_sampler(self.pixel_params,
                                           counters=self.counters)
        else:
            self.pixel_params = decoder(z,
                                        nr_logistic_mix=self.nr_logistic_mix)
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def __model(self):
    """Build an importance-weighted VAE + PixelCNN graph for the inference net.

    A copy of ``self.x`` that is masked by ``self.input_masks`` and noised
    outside the mask forms the first half of the batch. The unmasked
    ``self.x`` with an all-ones mask channel forms the second half. Both are
    encoded in one call. All per-example tensors are tiled
    ``self.num_particles`` (16) times and ``z = mu + eps * sigma`` is
    sampled. The decoder features are combined with the reverse pixelcnn's
    output on ``x`` and the input mask to condition the forward pixelcnn on
    ``x_bar``, and the result is scored per particle. The method sets
    ``self.pixel_params``, ``self.log_avg_weight``, ``self.params`` and
    ``self.gradients``.

    Raises:
        Exception: for an unsupported ``img_size`` / ``network_type`` pair.
    """
    print("****** Building Graph ******")
    # placeholders
    if self.network_type == 'binary':
        self.num_channels = 1
    else:
        self.num_channels = 3
    self.x = tf.placeholder(tf.float32,
                            shape=(self.batch_size, self.img_size,
                                   self.img_size, self.num_channels))
    self.x_bar = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size))
    self.input_masks = tf.placeholder(tf.float32,
                                      shape=(self.batch_size, self.img_size,
                                             self.img_size))
    # choose network size
    # encoder_q is only placed in the arg_scope below (never called), so
    # configurations without a q-encoder simply leave it out.
    encoder_q = None
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large_bn
            decoder = conv_decoder_32_large
            encoder_q = conv_encoder_32_q
        else:
            encoder = conv_encoder_32
            decoder = conv_decoder_32
        forward_pixelcnn = forward_pixel_cnn_32_small
        reverse_pixelcnn = reverse_pixel_cnn_32_small
    elif self.img_size == 28:
        if self.network_type == 'binary':
            encoder = conv_encoder_28_binary
            decoder = conv_decoder_28_binary
            forward_pixelcnn = forward_pixel_cnn_28_binary
            reverse_pixelcnn = reverse_pixel_cnn_28_binary
            encoder_q = conv_encoder_28_binary_q
        else:
            raise Exception("unknown network type")
    else:
        # Without this, the networks would be unbound (UnboundLocalError).
        raise Exception("unknown image size: {0}".format(self.img_size))
    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    scoped_fns = [forward_pixelcnn, reverse_pixelcnn, encoder, decoder]
    if encoder_q is not None:
        scoped_fns.append(encoder_q)
    with arg_scope(scoped_fns, **kwargs):
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                       **kwargs_pixelcnn):
            self.num_particles = 16
            # Keep pixels where input_masks is set; add uniform [-1, 1)
            # noise scaled by (1 - input_masks).
            inp = self.x * broadcast_masks_tf(self.input_masks,
                                              num_channels=self.num_channels)
            inp += tf.random_uniform(int_shape(inp), -1, 1) * (
                1 - broadcast_masks_tf(self.input_masks,
                                       num_channels=self.num_channels))
            inp = tf.concat(
                [inp,
                 broadcast_masks_tf(self.input_masks, num_channels=1)],
                axis=-1)
            inputs_pos = tf.concat([
                self.x,
                broadcast_masks_tf(tf.ones_like(self.input_masks),
                                   num_channels=1)
            ],
                                   axis=-1)
            # One encoder call: first batch_size rows masked, the rest unmasked.
            inp = tf.concat([inp, inputs_pos], axis=0)
            z_mu, z_log_sigma_sq = encoder(inp, self.z_dim)
            self.z_mu_pr, self.z_mu = (z_mu[:self.batch_size],
                                       z_mu[self.batch_size:])
            self.z_log_sigma_sq_pr, self.z_log_sigma_sq = (
                z_log_sigma_sq[:self.batch_size],
                z_log_sigma_sq[self.batch_size:])
            # tf.tile repeats the whole batch: row i * batch_size + b is
            # particle i of example b.
            x = tf.tile(self.x, [self.num_particles, 1, 1, 1])
            x_bar = tf.tile(self.x_bar, [self.num_particles, 1, 1, 1])
            input_masks = tf.tile(self.input_masks, [self.num_particles, 1, 1])
            masks = tf.tile(self.masks, [self.num_particles, 1, 1])
            self.z_mu_pr = tf.tile(self.z_mu_pr, [self.num_particles, 1])
            self.z_log_sigma_sq_pr = tf.tile(self.z_log_sigma_sq_pr,
                                             [self.num_particles, 1])
            self.z_mu = tf.tile(self.z_mu, [self.num_particles, 1])
            self.z_log_sigma_sq = tf.tile(self.z_log_sigma_sq,
                                          [self.num_particles, 1])
            # NOTE(review): posterior statistics are replaced by the masked
            # ("_pr") ones, so z is drawn from the _pr distribution and
            # epsilon_pr below equals epsilon. Looks deliberate -- confirm.
            self.z_mu, self.z_log_sigma_sq = (self.z_mu_pr,
                                              self.z_log_sigma_sq_pr)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.params = get_trainable_variables(["inference"])
            dist = tf.distributions.Normal(loc=0., scale=1.)
            epsilon = dist.sample(sample_shape=[
                self.batch_size * self.num_particles, self.z_dim
            ],
                                  seed=None)
            z = self.z_mu + tf.multiply(epsilon, sigma)
            decoded_features = decoder(z, output_features=True)
            r_outputs = reverse_pixelcnn(x, masks, context=None, bn=False)
            cond_features = tf.concat([r_outputs, decoded_features], axis=-1)
            cond_features = tf.concat(
                [broadcast_masks_tf(input_masks, num_channels=1), cond_features],
                axis=-1)
            self.pixel_params = forward_pixelcnn(x_bar, cond_features, bn=False)
            if self.network_type == 'binary':
                nll = bernoulli_loss(x,
                                     self.pixel_params,
                                     masks=masks,
                                     output_mean=False)
            else:
                nll = mix_logistic_loss(x,
                                        self.pixel_params,
                                        masks=masks,
                                        output_mean=False)
            log_prob_pos = dist.log_prob(epsilon)
            epsilon_pr = (z - self.z_mu_pr) / tf.exp(
                self.z_log_sigma_sq_pr / 2.)
            log_prob_pr = dist.log_prob(epsilon_pr)
            # convert back: regroup tiled rows as [num_particles, batch_size, ...]
            log_prob_pr = tf.stack([
                log_prob_pr[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                   axis=0)
            log_prob_pos = tf.stack([
                log_prob_pos[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                                    axis=0)
            log_prob_pr = tf.reduce_sum(log_prob_pr, axis=2)
            log_prob_pos = tf.reduce_sum(log_prob_pos, axis=2)
            nll = tf.stack([
                nll[self.batch_size * i:self.batch_size * (i + 1)]
                for i in range(self.num_particles)
            ],
                           axis=0)
            log_likelihood = -nll
            # NOTE(review): the prior/posterior correction is disabled, so
            # log_prob_pr and log_prob_pos do not feed into any output.
            # log_weights = log_prob_pr + log_likelihood - log_prob_pos
            log_weights = log_likelihood
            log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
            log_avg_weight = log_sum_weight - tf.log(
                tf.to_float(self.num_particles))
            self.log_avg_weight = log_avg_weight
            # Softmax weights over particles, held constant by stop_gradient;
            # objective is -sum_k w_k^2 * log_weights_k.
            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=0))
            sq_normalized_weights = tf.square(normalized_weights)
            self.gradients = tf.gradients(
                -tf.reduce_sum(sq_normalized_weights * log_weights, axis=0),
                self.params,
                colocate_gradients_with_ops=True)