def deconv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Weight-normalized transposed 2-D convolution with optional BN and activation.

    Args:
        inputs: 4-D input tensor with a fully static shape (read via ``int_shape``).
        num_filters: number of output channels.
        kernel_size: int, or a 2-element list/tuple ``[height, width]``.
        strides: int, or a 2-element list/tuple ``[height, width]``.
        padding: ``'SAME'`` or anything else (treated as ``'VALID'`` when the
            output shape is computed).
        nonlinearity: optional callable applied after batch normalization.
        bn: if true, apply ``tf.layers.batch_normalization`` to the output.
        kernel_initializer: accepted for interface compatibility; this
            implementation always initializes ``V`` from N(0, 0.05).
        kernel_regularizer: accepted for interface compatibility; unused here.
        is_training: forwarded to batch normalization.
        counters: dict passed to ``get_name`` to obtain the scope name.
            NOTE(review): the shared mutable default is left untouched because
            scope naming presumably relies on it -- confirm against ``get_name``.

    Returns:
        The output tensor of the layer.
    """
    # Accept scalar or tuple sizes: the code below indexes both values and
    # concatenates them with lists, which fails for a bare int (including the
    # default ``strides=1``) or a tuple.
    filter_size = [kernel_size, kernel_size] if isinstance(kernel_size, int) else list(kernel_size)
    stride = [strides, strides] if isinstance(strides, int) else list(strides)

    xs = int_shape(inputs)
    name = get_name('deconv2d', counters)
    if padding == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0], xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0],
                        xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1,
                        num_filters]
    with tf.variable_scope(name):
        V = tf.get_variable('V',
                            shape=filter_size + [num_filters, int(inputs.get_shape()[-1])],
                            dtype=tf.float32,
                            initializer=tf.random_normal_initializer(0, 0.05),
                            trainable=True)
        g = tf.get_variable('g', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(1.), trainable=True)
        b = tf.get_variable('b', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.), trainable=True)

        # use weight normalization (Salimans & Kingma, 2016)
        W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])

        # calculate convolutional layer output
        outputs = tf.nn.conv2d_transpose(inputs, W, target_shape, [1] + stride + [1], padding=padding)
        outputs = tf.nn.bias_add(outputs, b)

        if bn:
            outputs = tf.layers.batch_normalization(outputs, training=is_training)
        if nonlinearity is not None:
            outputs = nonlinearity(outputs)
        print(" + deconv2d", int_shape(inputs), int_shape(outputs), nonlinearity, bn)
        return outputs
def deconv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Transposed 2-D convolution followed by optional BN and activation.

    Delegates the convolution itself to ``deconv2d_openai``.

    Args:
        inputs: input tensor.
        num_filters: number of output channels.
        kernel_size: int (expanded to ``[k, k]``) or a 2-element list.
        strides: int (expanded to ``[s, s]``) or a 2-element list.
        padding: padding mode forwarded as ``pad``.
        nonlinearity: optional callable applied after batch normalization.
        bn: if true, apply ``tf.layers.batch_normalization``.
        kernel_initializer: forwarded to ``deconv2d_openai``.
        kernel_regularizer: forwarded to ``deconv2d_openai``.
        is_training: forwarded to batch normalization.
        counters: forwarded to ``deconv2d_openai``.

    Returns:
        The output tensor of the layer.
    """
    filter_hw = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size
    stride_hw = [strides] * 2 if isinstance(strides, int) else strides

    result = deconv2d_openai(
        inputs,
        num_filters,
        filter_size=filter_hw,
        stride=stride_hw,
        pad=padding,
        counters=counters,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer)

    if bn:
        result = tf.layers.batch_normalization(result, training=is_training)
    if nonlinearity is not None:
        result = nonlinearity(result)

    print(" + deconv2d", int_shape(inputs), int_shape(result), nonlinearity, bn)
    return result
def __model(self, x, is_training):
    """Build the VAE graph: encode ``x``, obtain a latent ``z`` and decode it.

    Sets ``self.x``, ``self.is_training``, ``self.z_mu``,
    ``self.z_log_sigma_sq``, ``self.z`` and ``self.x_hat``.

    Args:
        x: input tensor; dimension 1 of its static shape (64 or 32) selects
            the encoder/decoder pair.
        is_training: forwarded to the encoder and decoder through ``arg_scope``.

    Raises:
        ValueError: if dimension 1 of ``x`` is neither 64 nor 32. Previously
            this surfaced as an unbound-local error on ``encoder``.
    """
    print("****** Building Graph ******")
    self.x = x
    self.is_training = is_training

    size = int_shape(x)[1]
    if size == 64:
        encoder = conv_encoder_64_medium
        decoder = conv_decoder_64_medium
    elif size == 32:
        encoder = conv_encoder_32_medium
        decoder = conv_decoder_32_medium
    else:
        raise ValueError(
            "unsupported input size {0}; expected 32 or 64".format(size))

    with arg_scope([encoder, decoder],
                   nonlinearity=self.nonlinearity,
                   bn=self.bn,
                   kernel_initializer=self.kernel_initializer,
                   kernel_regularizer=self.kernel_regularizer,
                   is_training=self.is_training,
                   counters=self.counters):
        self.z_mu, self.z_log_sigma_sq = encoder(x, self.z_dim)
        # log(sigma^2) / 2 == log(sigma)
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            # In test mode the latent code is fed from outside.
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        # Any other use_mode leaves self.z as it was before this call.
        print("use mode:{0}".format(self.use_mode))
        self.x_hat = decoder(self.z)
def __model(self, x, x_bar, is_training, dropout_p, masks, input_masks, network_size="medium"):
    """Build the graph: conv encoder -> latent z -> conv decoder -> conditional PixelCNN.

    Sets ``self.x``, ``self.x_bar``, ``self.is_training``, ``self.dropout_p``,
    ``self.masks``, ``self.input_masks``, ``self.z_mu``,
    ``self.z_log_sigma_sq``, ``self.z``, ``self.decoded_features``,
    ``self.mix_logistic_params`` and ``self.x_hat``.

    Args:
        x: encoder input; dimension 1 of its static shape must be 64 or 32.
        x_bar: tensor fed to ``cond_pixel_cnn``.
        is_training: forwarded to the sub-networks through ``arg_scope``.
        dropout_p: forwarded to ``cond_pixel_cnn``.
        masks: stored on ``self``; not otherwise used in this method.
        input_masks: stored on ``self``; not otherwise used in this method.
        network_size: ``'medium'``, ``'large'`` or ``'large1'``; only
            consulted for 32-sized inputs.

    Raises:
        Exception: if ``network_size`` is unknown for a 32-sized input.
        ValueError: if dimension 1 of ``x`` is neither 64 nor 32. Previously
            this surfaced as an unbound-local error.
    """
    print("****** Building Graph ******")
    self.x = x
    self.x_bar = x_bar
    self.is_training = is_training
    self.dropout_p = dropout_p
    self.masks = masks
    self.input_masks = input_masks

    size = int_shape(x)[1]
    if size == 64:
        conv_encoder = conv_encoder_64_medium
        conv_decoder = conv_decoder_64_medium
    elif size == 32:
        if network_size == 'medium':
            conv_encoder = conv_encoder_32_medium
            conv_decoder = conv_decoder_32_medium
        elif network_size == 'large':
            conv_encoder = conv_encoder_32_large
            conv_decoder = conv_decoder_32_large
        elif network_size == 'large1':
            conv_encoder = conv_encoder_32_large1
            conv_decoder = conv_decoder_32_large1
        else:
            raise Exception("unknown network type")
    else:
        raise ValueError(
            "unsupported input size {0}; expected 32 or 64".format(size))

    with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                   nonlinearity=self.nonlinearity,
                   bn=self.bn,
                   kernel_initializer=self.kernel_initializer,
                   kernel_regularizer=self.kernel_regularizer,
                   is_training=self.is_training,
                   counters=self.counters):
        inputs = self.x
        self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
        # log(sigma^2) / 2 == log(sigma)
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            # In test mode the latent code is fed from outside.
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        print("use mode:{0}".format(self.use_mode))

        self.decoded_features = conv_decoder(self.z, output_features=True)
        sh = self.decoded_features
        # bn=False overrides the arg_scope value for the PixelCNN.
        self.mix_logistic_params = cond_pixel_cnn(
            self.x_bar,
            sh=sh,
            bn=False,
            dropout_p=self.dropout_p,
            nr_resnet=self.nr_resnet,
            nr_filters=self.nr_filters,
            nr_logistic_mix=self.nr_logistic_mix)
        self.x_hat = mix_logistic_sampler(
            self.mix_logistic_params,
            nr_logistic_mix=self.nr_logistic_mix,
            sample_range=self.sample_range,
            counters=self.counters)
def conv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False):
    """2-D convolution via ``tf.layers.conv2d``, then activation, then batch norm.

    Note the order: the nonlinearity is applied *before* batch normalization.

    Args:
        inputs: input tensor.
        num_filters: number of output channels.
        kernel_size: forwarded to ``tf.layers.conv2d``.
        strides: forwarded to ``tf.layers.conv2d``.
        padding: forwarded to ``tf.layers.conv2d``.
        nonlinearity: optional callable applied to the convolution output.
        bn: if true, apply ``tf.layers.batch_normalization`` last.
        kernel_initializer: forwarded to ``tf.layers.conv2d``.
        kernel_regularizer: forwarded to ``tf.layers.conv2d``.
        is_training: forwarded to batch normalization.

    Returns:
        The output tensor of the layer.
    """
    result = tf.layers.conv2d(
        inputs,
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer)

    if nonlinearity is not None:
        result = nonlinearity(result)
    if bn:
        result = tf.layers.batch_normalization(result, training=is_training)

    print(" + conv2d", int_shape(inputs), int_shape(result), nonlinearity, bn)
    return result
def gaussian_sampler(loc, scale, counters={}):
    """Draw ``loc + eps * scale`` with ``eps`` sampled from a standard normal.

    Args:
        loc: mean tensor; its static shape (via ``int_shape``) is the sample shape.
        scale: tensor multiplied element-wise with the noise.
        counters: dict passed to ``get_name`` for the scope name.

    Returns:
        The sampled tensor.
    """
    scope_name = get_name("gaussian_sampler", counters)
    print("construct", scope_name, "...")
    with tf.variable_scope(scope_name):
        standard_normal = tf.distributions.Normal(loc=0., scale=1.)
        noise = standard_normal.sample(sample_shape=int_shape(loc), seed=None)
        sample = loc + tf.multiply(noise, scale)
        print(" + gaussian_sampler", int_shape(sample))
        return sample
def bernoulli_loss(x, l, sum_all=True):
    """Sigmoid cross-entropy between labels ``x`` and logits ``l``.

    The element-wise cross-entropy is averaged over axis 3 and then summed.

    Args:
        x: label tensor (rank 4, since axis 3 is reduced).
        l: logits tensor, same shape as ``x``.
        sum_all: if true, return a scalar summed over all remaining axes;
            otherwise sum over axes 1 and 2 only.

    Returns:
        A scalar tensor, or a tensor with axes 1 and 2 summed out.
    """
    # The original also computed int_shape(x) / int_shape(l) into locals that
    # were never read; those dead lookups are removed.
    per_position = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=l), 3)
    if sum_all:
        return tf.reduce_sum(per_position)
    return tf.reduce_sum(per_position, [1, 2])
def dense(inputs, num_outputs, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False):
    """Fully connected layer via ``tf.layers.dense``, then activation, then batch norm.

    Note the order: the nonlinearity is applied *before* batch normalization.

    Args:
        inputs: rank-2 input tensor (asserted).
        num_outputs: number of output units.
        nonlinearity: optional callable applied to the dense output.
        bn: if true, apply ``tf.layers.batch_normalization`` last.
        kernel_initializer: forwarded to ``tf.layers.dense``.
        kernel_regularizer: forwarded to ``tf.layers.dense``.
        is_training: forwarded to batch normalization.

    Returns:
        The output tensor of the layer.
    """
    rank = len(int_shape(inputs))
    assert rank==2, "inputs should be flattened first"

    result = tf.layers.dense(
        inputs,
        num_outputs,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer)

    if nonlinearity is not None:
        result = nonlinearity(result)
    if bn:
        result = tf.layers.batch_normalization(result, training=is_training)

    print(" + dense", int_shape(inputs), int_shape(result), nonlinearity, bn)
    return result
def bernoulli_loss(x, l, masks=None, output_mean=True):
    """Sigmoid cross-entropy between labels ``x`` and logits ``l``, optionally masked.

    The element-wise cross-entropy is averaged over axis 3, multiplied by
    ``1 - masks`` when masks are given, and then summed.

    Args:
        x: label tensor (rank 4, since axis 3 is reduced).
        l: logits tensor, same shape as ``x``.
        masks: optional tensor with the same shape as the axis-3-reduced
            loss; positions where ``masks == 1`` contribute nothing.
        output_mean: if true, return a scalar summed over all axes; otherwise
            sum over axes 1 and 2 only. NOTE(review): despite the name, no
            mean over these axes is taken in either case -- confirm whether
            callers expect a sum.

    Returns:
        A scalar tensor, or a tensor with axes 1 and 2 summed out.
    """
    # The original also computed int_shape(x) / int_shape(l) into locals that
    # were never read; those dead lookups are removed.
    per_position = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=l), 3)
    if masks is not None:
        assert per_position.shape == masks.shape, "shape of masks does not match the log_sum_exp outputs"
        per_position *= (1 - masks)
    if output_mean:
        return tf.reduce_sum(per_position)
    return tf.reduce_sum(per_position, [1, 2])
def discretized_mix_logistic_loss(x,l,sum_all=True, masks=None):
    """Negative log-likelihood under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1,1] interval.

    Args:
        x: target image tensor, e.g. (B,32,32,3).
        l: predicted mixture parameters, e.g. (B,32,32,100). The last axis
            holds ``10 * nr_mix`` values: ``nr_mix`` mixture logits followed
            by means, log-scales and coefficients for each sub-pixel.
        sum_all: if true, return a scalar summed over all axes; otherwise sum
            over axes 1 and 2 only.
        masks: optional tensor with the same shape as the per-position
            log-likelihood; positions where ``masks == 1`` contribute nothing.

    Returns:
        The negated (optionally masked) sum of per-position log-likelihoods.
    """
    xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
    nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
    logit_probs = l[:,:,:,:nr_mix]
    l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
    means = l[:,:,:,:,:nr_mix]
    # log-scales are clamped from below at -7
    log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
    coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])
    x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
    m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
    m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
    means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
    centered_x = x - means
    inv_stdv = tf.exp(-log_scales)
    # CDF arguments at the upper / lower edge of the bin (half-width 1/255)
    plus_in = inv_stdv * (centered_x + 1./255.)
    cdf_plus = tf.nn.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1./255.)
    cdf_min = tf.nn.sigmoid(min_in)
    log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
    log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
    cdf_delta = cdf_plus - cdf_min # probability for all other cases
    mid_in = inv_stdv * centered_x
    log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

    # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
    # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
    # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value
    log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

    # sum the sub-pixel log-probs (axis 3) and add the mixture log-weights
    log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
    lse = log_sum_exp(log_probs)
    if masks is not None:
        assert lse.shape==masks.shape, "shape of masks does not match the log_sum_exp outputs"
        # zero out positions where masks == 1
        lse *= (1 - masks)
    if sum_all:
        return -tf.reduce_sum(lse)
    else:
        return -tf.reduce_sum(lse,[1,2])
def conditioning_network_28(x, masks, nr_filters, is_training=True, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, counters={}):
    """Stack of stride-1 convolutions over a masked input plus its mask.

    The input is multiplied by the broadcast mask, the mask is appended as a
    channel, a channel of ones is appended, and the result passes through
    five 4x4 convolutions followed by a 1x1 convolution without activation
    or batch norm.

    Args:
        x: input tensor, multiplied by ``broadcast_masks_tf(masks, num_channels=3)``.
        masks: mask tensor given to ``broadcast_masks_tf``.
        nr_filters: output channels of every convolution.
        is_training, nonlinearity, bn, kernel_initializer, kernel_regularizer,
        counters: forwarded to the layers through ``arg_scope``.

    Returns:
        The output of the final 1x1 convolution.
    """
    scope_name = get_name("conditioning_network_28", counters)

    masked = x * broadcast_masks_tf(masks, num_channels=3)
    with_mask = tf.concat(
        [masked, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    ones_channel = tf.ones(int_shape(with_mask)[:-1] + [1])
    net_in = tf.concat([with_mask, ones_channel], 3)

    layer_defaults = dict(
        nonlinearity=nonlinearity,
        bn=bn,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        is_training=is_training,
        counters=counters)

    with tf.variable_scope(scope_name):
        with arg_scope([conv2d, residual_block, dense], **layer_defaults):
            features = conv2d(net_in, nr_filters, 4, 1, "SAME")
            for _ in range(4):
                features = conv2d(features, nr_filters, 4, 1, "SAME")
            # Final 1x1 projection: no activation, no batch norm.
            features = conv2d(features, nr_filters, 1, 1, "SAME",
                              nonlinearity=None, bn=False)
            return features
def conditional_decoder(x, z, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Gated MLP mapping ``x`` conditioned on ``z`` to one scalar per row of ``x``.

    ``x`` is tiled along axis 1 by ``z``'s second dimension and ``z`` is tiled
    along axis 0 by the dynamic batch size of ``x``. Five gated layers of
    width 256 follow, each computing ``a = dense(h) + dense(z)`` and
    ``h = tanh(a) * sigmoid(a)``, then a final ``dense`` to one unit.

    Args:
        x: rank-2 input tensor.
        z: rank-2 conditioning tensor.
        nonlinearity, bn, kernel_initializer, kernel_regularizer, is_training:
            defaults installed on ``dense`` through ``arg_scope`` (every call
            here overrides ``nonlinearity`` with ``None``).
        counters: dict passed to ``get_name`` for the scope name.

    Returns:
        A tensor reshaped to ``(batch_size,)``.
    """
    scope_name = get_name("conditional_decoder", counters)
    print("construct", scope_name, "...")
    with tf.variable_scope(scope_name):
        with arg_scope([dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training):
            width = 256
            batch_size = tf.shape(x)[0]
            x = tf.tile(x, tf.stack([1, int_shape(z)[1]]))
            z = tf.tile(z, tf.stack([batch_size, 1]))
            # A learned mix "x + z * coeff" was once tried here and is disabled;
            # the first layer consumes the tiled x directly.
            hidden = x
            # First layer plus four more identical gated layers.
            for _ in range(5):
                from_hidden = dense(hidden, width, nonlinearity=None)
                from_z = dense(z, width, nonlinearity=None)
                pre_act = from_hidden + from_z
                hidden = tf.nn.tanh(pre_act) * tf.sigmoid(pre_act)
            logits = dense(hidden, 1, nonlinearity=None, bn=False)
            return tf.reshape(logits, shape=(batch_size, ))
def estimate_mi_tc_dwkld(z, z_mu, z_log_sigma_sq, N=2e5):
    """Jointly estimate mutual information, total correlation and dimension-wise KL.

    The three terms share the pairwise log-density computation, so it is
    cheaper to compute them together.

    Args:
        z: sampled latent codes; same static shape as ``z_mu``.
        z_mu: posterior means with static shape ``(batch_size, z_dim)``.
        z_log_sigma_sq: posterior log-variances, same shape as ``z_mu``.
        N: dataset size used in the importance weights.

    Returns:
        Tuple ``(mi, tc, dwkld)`` of scalar tensors.
    """
    batch_size, z_dim = int_shape(z_mu)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))

    # Pairwise standardized residuals: index [i, j, d] uses z[j] with the
    # posterior parameters of example i.
    z_rep = tf.stack([z] * batch_size, axis=0)
    mu_rep = tf.stack([z_mu] * batch_size, axis=1)
    sigma_rep = tf.stack([z_sigma] * batch_size, axis=1)
    standardized = (z_rep - mu_rep) / sigma_rep
    log_density = tf.distributions.Normal(loc=0., scale=1.).log_prob(standardized)

    # Log importance weight for off-diagonal pairs; zero on the diagonal.
    log_weight = np.log(float(N - 1) / (batch_size - 1))
    ratio = log_weight * np.ones((batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_per_dim = np.stack([ratio] * z_dim, axis=-1)

    joint = tf.reduce_sum(log_density, axis=-1) + ratio
    lse_sum = tf.reduce_mean(log_sum_exp(joint, axis=0))
    marginals = log_sum_exp(log_density + ratio_per_dim, axis=0)
    sum_lse = tf.reduce_mean(tf.reduce_sum(marginals, axis=-1))

    log_n = tf.log(float(N))
    lse_sum -= log_n
    sum_lse -= tf.log(float(N)) * float(z_dim)

    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))

    mi = -lse_sum - cond_entropy
    tc = lse_sum - sum_lse
    dwkld = sum_lse + (kld + cond_entropy)
    return mi, tc, dwkld
def gated_resnet(x, a=None, gh=None, sh=None, nonlinearity=tf.nn.elu, conv=conv2d, dropout_p=0.0, counters={}, **kwargs):
    """Gated residual block: ``x + A * sigmoid(B)`` where ``[A, B]`` come from two convs.

    Args:
        x: input tensor; its last static dimension sets the filter count.
        a: optional auxiliary input added through ``nin`` as a short-cut.
        gh: accepted but currently ignored (unfinished in the original).
        sh: optional conditioning tensor projected with ``nin`` and added
            before gating.
        nonlinearity: activation applied before each convolution.
        conv: convolution function used for both layers.
        dropout_p: dropout probability (``keep_prob = 1 - dropout_p``).
        counters: dict passed to ``get_name`` and installed on ``conv``.
        **kwargs: extra defaults installed on ``conv`` through ``arg_scope``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    block_name = get_name("gated_resnet", counters)
    print("construct", block_name, "...")
    num_filters = int_shape(x)[-1]
    kwargs["counters"] = counters

    with arg_scope([conv], **kwargs):
        hidden = conv(nonlinearity(x), num_filters)
        if a is not None:
            # add short-cut connection if auxiliary input 'a' is given
            hidden = hidden + nin(nonlinearity(a), num_filters)
        hidden = tf.nn.dropout(nonlinearity(hidden), keep_prob=1. - dropout_p)

        gated_in = conv(hidden, num_filters * 2)
        if sh is not None:
            # add projection of h vector if included: conditional generation
            gated_in = gated_in + nin(sh, 2 * num_filters, nonlinearity=nonlinearity)
        # gh conditioning is not implemented; the argument has no effect.

        value, gate = tf.split(gated_in, 2, 3)
        return x + value * tf.nn.sigmoid(gate)
def estimate_mi(z, z_mu, z_log_sigma_sq, N=200000):
    """Estimate the mutual information term from a minibatch.

    Args:
        z: sampled latent codes.
        z_mu: posterior means.
        z_log_sigma_sq: posterior log-variances.
        N: dataset size, forwarded to ``estimate_log_probs`` and used for the
            ``log N`` correction.

    Returns:
        Scalar tensor ``-(lse_sum - log N) - mean(conditional entropy)``.
    """
    # Only the first estimate_log_probs output is needed; the original also
    # unpacked int_shape(z_mu) into locals that were never read.
    lse_sum, _ = estimate_log_probs(z, z_mu, z_log_sigma_sq, N=N)
    lse_sum -= tf.log(float(N))
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))
    return -lse_sum - cond_entropy
def reverse_pixel_cnn_28_binary(x, masks, context=None, nr_logistic_mix=10, nr_resnet=1, nr_filters=100, dropout_p=0.0, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """PixelCNN-style network built from up / up-left shifted convolutions.

    The input is multiplied by the broadcast mask and the mask is appended as
    a channel before the two streams are built.

    Args:
        x: input tensor.
        masks: mask tensor given to ``broadcast_masks_tf``.
        context: optional conditioning tensor installed as ``sh`` on
            ``gated_resnet``.
        nr_logistic_mix: only printed here.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channels of every layer and of the output.
        dropout_p, nonlinearity: installed on ``gated_resnet``.
        bn: must be false (asserted); note the default ``True`` fails the assert.
        kernel_initializer, kernel_regularizer, is_training, counters:
            installed on the layers through ``arg_scope``.

    Returns:
        Output of the final ``nin`` layer with ``nr_filters`` channels.
    """
    scope_name = get_name("reverse_pixel_cnn_28_binary", counters)
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)

    print("construct", scope_name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    shifted_layers = [
        gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
        up_shifted_deconv2d, up_left_shifted_deconv2d
    ]
    with tf.variable_scope(scope_name):
        with arg_scope([gated_resnet], gh=None, sh=context,
                       nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope(shifted_layers,
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                # add channel of ones to distinguish image from padding later on
                ones_channel = tf.ones(int_shape(x)[:-1] + [1])
                x_pad = tf.concat([x, ones_channel], 3)

                # vertical stream
                u = up_shift(
                    up_shifted_conv2d(x_pad, num_filters=nr_filters,
                                      filter_size=[2, 3]))
                # vertical + horizontal stream
                ul_vertical = up_shift(
                    up_shifted_conv2d(x_pad, num_filters=nr_filters,
                                      filter_size=[1, 3]))
                ul_horizontal = left_shift(
                    up_left_shifted_conv2d(x_pad, num_filters=nr_filters,
                                           filter_size=[2, 1]))
                ul = ul_vertical + ul_horizontal

                for _ in range(nr_resnet):
                    u = gated_resnet(u, conv=up_shifted_conv2d)
                    ul = gated_resnet(ul, u, conv=up_left_shifted_conv2d)

                return nin(tf.nn.elu(ul), nr_filters)
def __loss(self, reg):
    """Build the total loss: mixture-of-logistics reconstruction plus a regularizer.

    Sets ``self.loss_ae``, ``self.loss_reg`` and ``self.loss`` and whichever
    of ``self.mmd``, ``self.kld``, ``self.mi``, ``self.tc``, ``self.dwkld``,
    ``self.dwmmd`` and ``self.mmdtc`` the chosen regularizer computes (the
    others stay ``None``).

    Args:
        reg: one of ``None``, ``'kld'``, ``'mmd'``, ``'tc'``, ``'info-tc'``,
            ``'tc-dwmmd'`` or ``'mmd-tc'``.

    Raises:
        ValueError: if ``reg`` is not one of the values above. Previously
            ``self.loss_reg`` was simply left unset in that case.
    """
    print("****** Compute Loss ******")
    self.mmd = self.kld = self.mi = self.tc = self.dwkld = None
    self.gamma, self.dwmmd = 1e3, None  # hard coded, experimental
    self.mmdtc = None
    self.loss_ae = mix_logistic_loss(self.x, self.mix_logistic_params,
                                     masks=self.masks)

    if reg is None:
        self.loss_reg = 0
    elif reg == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif reg == 'mmd':
        # Prior sample count is fixed at 256 rather than taken from z's shape.
        self.mmd = estimate_mmd(
            tf.random_normal(tf.stack([256, self.z_dim])), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    elif reg in ('tc', 'info-tc', 'tc-dwmmd'):
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        if reg == 'tc':
            self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
        elif reg == 'info-tc':
            self.loss_reg = self.beta * self.tc + self.dwkld
        else:
            self.dwmmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                      self.z, is_dimention_wise=True)
            self.loss_reg = self.beta * self.tc + self.dwmmd * self.gamma
    elif reg == 'mmd-tc':
        self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)), self.z)
        self.mmdtc = estimate_mmdtc(self.z, self.random_indices)
        self.loss_reg = (self.mmd + self.beta * self.mmdtc) * 1e5
        # MI does not enter this loss; it is only stored on self.
        self.mi = estimate_mi(self.z, self.z_mu, self.z_log_sigma_sq,
                              N=200000)
    else:
        raise ValueError("unknown regularizer: {0!r}".format(reg))

    print("reg:{0}, beta:{1}, lam:{2}".format(self.reg, self.beta, self.lam))
    self.loss = self.loss_ae + self.loss_reg
def cond_pixel_cnn(x, gh=None, sh=None, nonlinearity=tf.nn.elu, nr_resnet=5, nr_filters=100, nr_logistic_mix=10, bn=False, dropout_p=0.0, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Conditional PixelCNN built from down / down-right shifted convolutions.

    Args:
        x: input image tensor.
        gh, sh: conditioning inputs installed on ``gated_resnet``.
        nonlinearity: activation installed on ``gated_resnet``.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channels of the hidden layers.
        nr_logistic_mix: number of mixture components; the output has
            ``10 * nr_logistic_mix`` channels.
        bn: must be false (asserted).
        dropout_p: installed on ``gated_resnet``.
        kernel_initializer, kernel_regularizer, is_training: installed on the
            layers through ``arg_scope``.
        counters: passed to ``get_name`` and installed on ``gated_resnet``.

    Returns:
        Output of the final ``nin`` layer.
    """
    scope_name = get_name("conv_pixel_cnn", counters)
    print("construct", scope_name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    with tf.variable_scope(scope_name):
        with arg_scope([gated_resnet], gh=gh, sh=sh,
                       nonlinearity=nonlinearity, dropout_p=dropout_p,
                       counters=counters):
            with arg_scope(
                [gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                # add channel of ones to distinguish image from padding later on
                ones_channel = tf.ones(int_shape(x)[:-1] + [1])
                x_pad = tf.concat([x, ones_channel], 3)

                # stream for pixels above
                u = down_shift(
                    down_shifted_conv2d(x_pad, num_filters=nr_filters,
                                        filter_size=[2, 3]))
                # stream for up and to the left
                ul_above = down_shift(
                    down_shifted_conv2d(x_pad, num_filters=nr_filters,
                                        filter_size=[1, 3]))
                ul_left = right_shift(
                    down_right_shifted_conv2d(x_pad, num_filters=nr_filters,
                                              filter_size=[2, 1]))
                ul = ul_above + ul_left

                for _ in range(nr_resnet):
                    u = gated_resnet(u, conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, u, conv=down_right_shifted_conv2d)

                # Starts at (2, 3); each block adds (1, 2).
                receptive_field = (2 + nr_resnet, 3 + 2 * nr_resnet)

                x_out = nin(tf.nn.elu(ul), 10 * nr_logistic_mix)
                print(" * receptive_field", receptive_field)
                return x_out
def estimate_mmdtc(y, indices):
    """MMD between the first half of a batch and a per-dimension re-indexed second half.

    Args:
        y: tensor with static shape ``(batch_size, z_dim)``.
        indices: 2-D index tensor; column ``d`` gives the gather order applied
            to dimension ``d`` of the second half.

    Returns:
        ``estimate_mmd(first_half, reindexed_second_half)``.
    """
    total, z_dim = int_shape(y)
    half = total // 2
    first_half, second_half = y[:half], y[half:]

    columns = tf.unstack(second_half, axis=1)
    reindexed = [tf.gather(columns[d], indices[:, d]) for d in range(z_dim)]
    return estimate_mmd(first_half, tf.stack(reindexed, axis=1))
def estimate_dwkld(z, z_mu, z_log_sigma_sq, N=200000):
    """Estimate the dimension-wise KL term from a minibatch.

    Args:
        z: sampled latent codes.
        z_mu: posterior means with static shape ``(batch_size, z_dim)``.
        z_log_sigma_sq: posterior log-variances.
        N: dataset size, forwarded to ``estimate_log_probs`` and used for the
            ``z_dim * log N`` correction.

    Returns:
        Scalar tensor ``(sum_lse - z_dim * log N) + kld + mean(conditional entropy)``.
    """
    # Only z_dim and the second estimate_log_probs output are needed; the
    # unused batch_size and lse_sum bindings of the original are dropped.
    z_dim = int_shape(z_mu)[1]
    _, sum_lse = estimate_log_probs(z, z_mu, z_log_sigma_sq, N=N)
    sum_lse -= tf.log(float(N)) * float(z_dim)
    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))
    nll_prior = kld + cond_entropy
    return sum_lse + nll_prior
def context_encoder(contexts, masks, is_training, nr_resnet=5, nr_filters=100, nonlinearity=None, bn=False, kernel_initializer=None, kernel_regularizer=None, counters={}):
    """Encode masked context images with up / up-left shifted gated blocks.

    The context is multiplied by the broadcast mask and the mask is appended
    as a channel before the two streams are built.

    Args:
        contexts: context image tensor.
        masks: mask tensor given to ``broadcast_masks_tf``.
        is_training: installed on the layers through ``arg_scope``.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channels of every layer and of the output.
        nonlinearity: installed on ``gated_resnet``.
        bn: installed on the layers; a notice is printed when true.
        kernel_initializer, kernel_regularizer: installed on the layers.
        counters: passed to ``get_name`` and installed on ``gated_resnet``.

    Returns:
        Output of the final ``nin`` layer with ``nr_filters`` channels.
    """
    scope_name = get_name("context_encoder", counters)
    print("construct", scope_name, "...")
    x = contexts * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")

    with tf.variable_scope(scope_name):
        with arg_scope([gated_resnet], nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope(
                [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                # add channel of ones to distinguish image from padding later on
                ones_channel = tf.ones(int_shape(x)[:-1] + [1])
                x_pad = tf.concat([x, ones_channel], 3)

                # vertical stream
                u = up_shift(
                    up_shifted_conv2d(x_pad, num_filters=nr_filters,
                                      filter_size=[2, 3]))
                # vertical + horizontal stream
                ul_vertical = up_shift(
                    up_shifted_conv2d(x_pad, num_filters=nr_filters,
                                      filter_size=[1, 3]))
                ul_horizontal = left_shift(
                    up_left_shifted_conv2d(x_pad, num_filters=nr_filters,
                                           filter_size=[2, 1]))
                ul = ul_vertical + ul_horizontal

                for _ in range(nr_resnet):
                    u = gated_resnet(u, conv=up_shifted_conv2d)
                    ul = gated_resnet(ul, u, conv=up_left_shifted_conv2d)

                # Starts at (2, 3); each block adds (1, 2).
                receptive_field = (2 + nr_resnet, 3 + 2 * nr_resnet)

                x_out = nin(tf.nn.elu(ul), nr_filters)
                print(" * receptive_field", receptive_field)
                return x_out
def dense(inputs, num_outputs, W=None, b=None, nonlinearity=None, bn=False, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Fully connected layer with optional externally supplied weights.

    Computes ``inputs @ W + b``, then optional batch norm, then the optional
    nonlinearity.

    Args:
        inputs: rank-2 input tensor.
        num_outputs: number of output units.
        W: optional weight tensor; created as variable ``'W'`` when ``None``.
        b: optional bias tensor; created as zero-initialized variable ``'b'``
            when ``None``.
        nonlinearity: optional callable applied last.
        bn: if true, apply ``tf.layers.batch_normalization``.
        kernel_initializer: initializer for a newly created ``W``.
        kernel_regularizer: regularizer for a newly created ``W``.
        is_training: forwarded to batch normalization.
        counters: dict passed to ``get_name`` for the scope name.

    Returns:
        The output tensor of the layer.
    """
    scope_name = get_name('dense', counters)
    with tf.variable_scope(scope_name):
        weights = W if W is not None else tf.get_variable(
            'W',
            shape=[int(inputs.get_shape()[1]), num_outputs],
            dtype=tf.float32,
            trainable=True,
            initializer=kernel_initializer,
            regularizer=kernel_regularizer)
        bias = b if b is not None else tf.get_variable(
            'b',
            shape=[num_outputs],
            dtype=tf.float32,
            trainable=True,
            initializer=tf.constant_initializer(0.),
            regularizer=None)

        result = tf.matmul(inputs, weights) + tf.reshape(bias, [1, num_outputs])

        if bn:
            result = tf.layers.batch_normalization(result, training=is_training)
        if nonlinearity is not None:
            result = nonlinearity(result)

        print(" + dense", int_shape(inputs), int_shape(result), nonlinearity, bn)
        return result
def up_left_shifted_deconv2d(x, num_filters, filter_size=[2, 2], strides=[1, 1], **kwargs):
    """Transposed convolution cropped at the top/left.

    Runs ``deconv2d`` with ``'VALID'`` padding and then drops the first
    ``filter_size[0] - 1`` rows and ``filter_size[1] - 1`` columns. This
    mirrors the down/right variants in this file, which keep the leading
    rows/columns and drop the trailing ones.

    The original crop started at ``xs[1] - filter_size[0] + 1`` (and likewise
    for columns), which kept only the last ``filter_size - 1`` rows/columns
    instead of removing them.

    Args:
        x: input tensor.
        num_filters: number of output channels.
        filter_size: ``[height, width]`` of the kernel (not mutated).
        strides: ``[height, width]`` strides (not mutated).
        **kwargs: forwarded to ``deconv2d``.

    Returns:
        The cropped output tensor. Assuming the VALID output has
        ``in * stride + filter - 1`` rows/columns, as the ``deconv2d``
        definitions in this file compute, the result has ``in * stride``.
    """
    x = deconv2d(x,
                 num_filters,
                 kernel_size=filter_size,
                 strides=strides,
                 padding='VALID',
                 **kwargs)
    top_crop = filter_size[0] - 1
    left_crop = filter_size[1] - 1
    return x[:, top_crop:, left_crop:, :]
def down_shifted_deconv2d(x, num_filters, filter_size=[2, 3], strides=[1, 1], **kwargs):
    """Transposed convolution cropped at the bottom and symmetrically at the sides.

    Runs ``deconv2d`` with ``'VALID'`` padding, keeps the first
    ``rows - filter_size[0] + 1`` rows and trims ``(filter_size[1] - 1) / 2``
    (truncated to int) columns from each side.

    Args:
        x: input tensor.
        num_filters: number of output channels.
        filter_size: ``[height, width]`` of the kernel.
        strides: ``[height, width]`` strides.
        **kwargs: forwarded to ``deconv2d``.

    Returns:
        The cropped output tensor.
    """
    out = deconv2d(x,
                   num_filters,
                   kernel_size=filter_size,
                   strides=strides,
                   padding='VALID',
                   **kwargs)
    out_shape = int_shape(out)
    kept_rows = out_shape[1] - filter_size[0] + 1
    side_crop = int((filter_size[1] - 1) / 2)
    return out[:, :kept_rows, side_crop:(out_shape[2] - side_crop), :]
def deconv2d_openai(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', nonlinearity=None, kernel_initializer=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    """Plain transposed 2-D convolution with bias (no weight normalization).

    Args:
        x: 4-D input tensor with a fully static shape.
        num_filters: number of output channels.
        filter_size: ``[height, width]`` list.
        stride: ``[height, width]`` list.
        pad: ``'SAME'`` or anything else (treated as ``'VALID'`` when the
            output shape is computed).
        kernel_initializer: initializer for the kernel variable ``'V'``.
        counters: dict passed to ``get_name`` for the scope name.
        nonlinearity, init_scale, init, ema, **kwargs: accepted but unused here.

    Returns:
        The transposed-convolution output with the bias added.
    """
    scope_name = get_name("deconv2d", counters)
    batch, in_h, in_w = int_shape(x)[:3]

    out_h = in_h * stride[0]
    out_w = in_w * stride[1]
    if pad != 'SAME':
        # VALID padding grows the output by (filter - 1) on each axis.
        out_h += filter_size[0] - 1
        out_w += filter_size[1] - 1
    target_shape = [batch, out_h, out_w, num_filters]

    with tf.variable_scope(scope_name):
        in_channels = int(x.get_shape()[-1])
        kernel = tf.get_variable('V',
                                 shape=filter_size + [num_filters, in_channels],
                                 dtype=tf.float32,
                                 initializer=kernel_initializer,
                                 trainable=True)
        bias = tf.get_variable('b',
                               shape=[num_filters],
                               dtype=tf.float32,
                               initializer=tf.constant_initializer(0.),
                               trainable=True)
        out = tf.nn.conv2d_transpose(x, kernel, target_shape,
                                     [1] + stride + [1], padding=pad)
        return tf.nn.bias_add(out, bias)
def sample_from_discretized_mix_logistic(l, nr_mix, epsilon=1e-5):
    """Sample a 3-channel image from mixture-of-logistics parameters.

    Args:
        l: 4-D parameter tensor; the last axis holds ``nr_mix`` mixture logits
            followed by means, log-scales and coefficients.
        nr_mix: number of mixture components.
        epsilon: the logistic noise is drawn from uniform values in
            ``[epsilon, 1 - epsilon]``.

    Returns:
        Tensor with the leading three dimensions of ``l`` and 3 channels,
        each value clipped to [-1, 1].
    """
    ls = int_shape(l)
    xs = ls[:-1] + [3]
    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])
    # sample mixture indicator from softmax
    # (argmax over logits perturbed by -log(-log(uniform)) noise, one-hot encoded)
    sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
    sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])
    # select logistic parameters
    means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
    log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
    coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = tf.random_uniform(means.get_shape(), minval=epsilon, maxval=1. - epsilon)
    x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))
    # each later channel is shifted by a linear function of the earlier,
    # already clipped channels
    x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
    x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
    x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
    return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)
def __loss(self, reg):
    """Build the total loss: Gaussian reconstruction loss plus a regularizer.

    Sets ``self.loss_ae``, ``self.loss_reg`` and ``self.loss`` and whichever
    of ``self.mmd``, ``self.kld``, ``self.mi``, ``self.tc`` and ``self.dwkld``
    the chosen regularizer computes (the others stay ``None``).

    Args:
        reg: one of ``None``, ``'kld'``, ``'mmd'`` or ``'tc'``.

    Raises:
        ValueError: if ``reg`` is not one of the values above. Previously
            ``self.loss_reg`` was simply left unset in that case.
    """
    print("****** Compute Loss ******")
    self.mmd = self.kld = self.mi = self.tc = self.dwkld = None
    self.loss_ae = gaussian_recons_loss(self.x, self.x_hat)

    if reg is None:
        self.loss_reg = 0
    elif reg == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif reg == 'mmd':
        self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    elif reg == 'tc':
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
    else:
        raise ValueError("unknown regularizer: {0!r}".format(reg))

    print("reg:{0}, beta:{1}, lam:{2}".format(self.reg, self.beta, self.lam))
    self.loss = self.loss_ae + self.loss_reg
def estimate_tc(z, z_mu, z_log_sigma_sq, N=200000):
    """Estimate the total correlation term from a minibatch.

    Args:
        z: sampled latent codes; same static shape as ``z_mu``.
        z_mu: posterior means with static shape ``(batch_size, z_dim)``.
        z_log_sigma_sq: posterior log-variances, same shape as ``z_mu``.
        N: dataset size used in the importance weights.

    Returns:
        Scalar tensor ``lse_sum - sum_lse + log(N) * (z_dim - 1)``.
    """
    batch_size, z_dim = int_shape(z_mu)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))

    # Pairwise standardized residuals: index [i, j, d] uses z[j] with the
    # posterior parameters of example i.
    z_rep = tf.stack([z] * batch_size, axis=0)
    mu_rep = tf.stack([z_mu] * batch_size, axis=1)
    sigma_rep = tf.stack([z_sigma] * batch_size, axis=1)
    standardized = (z_rep - mu_rep) / sigma_rep
    log_density = tf.distributions.Normal(loc=0., scale=1.).log_prob(standardized)

    # Log importance weight for off-diagonal pairs; zero on the diagonal.
    log_weight = np.log(float(N - 1) / (batch_size - 1))
    ratio = log_weight * np.ones((batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_per_dim = np.stack([ratio] * z_dim, axis=-1)

    joint = tf.reduce_sum(log_density, axis=-1) + ratio
    lse_sum = tf.reduce_mean(log_sum_exp(joint, axis=0))
    marginals = log_sum_exp(log_density + ratio_per_dim, axis=0)
    sum_lse = tf.reduce_mean(tf.reduce_sum(marginals, axis=-1))

    return lse_sum - sum_lse + tf.log(float(N)) * (float(z_dim) - 1)
def __model(self, network_type="large"):
    """Build the graph: placeholders, VAE encoder/decoder and forward/reverse PixelCNNs.

    Creates the placeholders ``self.x``, ``self.x_bar``, ``self.is_training``,
    ``self.dropout_p``, ``self.masks``, ``self.input_masks``,
    ``self.use_prior``, ``self.z_ph`` and ``self.use_z_ph``, and sets
    ``self.z_mu``, ``self.z_log_sigma_sq``, ``self.z``, ``self.z_pr``,
    ``self.mix_logistic_params`` and ``self.x_hat``.

    Args:
        network_type: not read by this method; ``self.network_type`` is used
            instead. Kept for interface compatibility.

    Raises:
        ValueError: if ``self.img_size`` is not 32. Previously this surfaced
            as an unbound-local error when the networks were first used.
    """
    print("****** Building Graph ******")
    # placeholders
    num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size, num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.use_prior = tf.placeholder_with_default(False, shape=())

    # choose network size
    if self.img_size == 32:
        if self.network_type == 'large':
            encoder = conv_encoder_32_large
            decoder = conv_decoder_32_large
        else:
            encoder = conv_encoder_32
            decoder = conv_decoder_32
        forward_pixelcnn = forward_pixel_cnn_32_small
        reverse_pixelcnn = reverse_pixel_cnn_32_small
    else:
        raise ValueError(
            "unsupported img_size {0}; expected 32".format(self.img_size))

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder], **kwargs):
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn], **kwargs_pixelcnn):
            inputs = self.x
            if self.input_masks is not None:
                # Multiply by the mask, add uniform noise in [-1, 1) scaled by
                # (1 - mask), then append the mask as an extra channel.
                inputs = inputs * broadcast_masks_tf(self.input_masks, num_channels=3)
                inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                    1 - broadcast_masks_tf(self.input_masks, num_channels=3))
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ], axis=-1)

            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            # log(sigma^2) / 2 == log(sigma)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)
            self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                         tf.ones_like(sigma))
            self.z_ph = tf.placeholder_with_default(
                tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            # Turn the boolean switches into 0/1 floats and blend: the fed
            # z_ph wins over the prior sample, which wins over the posterior.
            use_prior = tf.cast(tf.cast(self.use_prior, tf.int32), tf.float32)
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                1 - use_z_ph) + use_z_ph * self.z_ph

            decoded_features = decoder(z, output_features=True)
            r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
            cond_features = tf.concat([r_outputs, decoded_features], axis=-1)
            self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                        cond_features,
                                                        bn=False)
            self.x_hat = mix_logistic_sampler(
                self.mix_logistic_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def forward_pixel_cnn_32(x, context, nr_logistic_mix=10, nr_resnet=1, nr_filters=100, dropout_p=0.0, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """U-shaped PixelCNN with two stride-2 down-samplings and matching up-samplings.

    The up pass pushes every intermediate activation of both streams onto
    ``u_list`` / ``ul_list``; the down pass pops them back as skip inputs, so
    both lists must be empty at the end (asserted).

    Args:
        x: input image tensor.
        context: conditioning tensor, passed as ``sh`` only to the gated
            blocks of the first loop of the up pass.
        nr_logistic_mix: number of mixture components; the output has
            ``10 * nr_logistic_mix`` channels.
        nr_resnet: number of gated residual blocks per stage.
        nr_filters: channels of the hidden layers.
        dropout_p, nonlinearity: installed on ``gated_resnet``.
        bn: must be false (asserted); note the default ``True`` fails the assert.
        kernel_initializer, kernel_regularizer, is_training, counters:
            installed on the layers through ``arg_scope``.

    Returns:
        Output of the final ``nin`` layer.
    """
    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=None, nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d, down_shifted_deconv2d, down_right_shifted_deconv2d], bn=bn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, is_training=is_training, counters=counters):
                xs = int_shape(x)
                x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

                # /////// up pass ////////
                u_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left

                # first stage: the only gated blocks that receive `context`
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], sh=context, conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], sh=context, conv=down_right_shifted_conv2d))

                # first stride-2 down-sampling
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # second stride-2 down-sampling
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # /////// down pass ////////
                u = u_list.pop()
                ul = ul_list.pop()
                for rep in range(nr_resnet):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # first stride-2 up-sampling
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                # nr_resnet+1 pops: the stage's blocks plus the down-sampling output pushed above
                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # second stride-2 up-sampling
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), sh=None, conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), sh=None, conv=down_right_shifted_conv2d)

                x_out = nin(tf.nn.elu(ul),10*nr_logistic_mix)

                # every skip activation pushed in the up pass must have been consumed
                assert len(u_list) == 0
                assert len(ul_list) == 0

                return x_out