def model_spec(x):
    counters = {}
    xs = int_shape(x)
    sum_log_det_jacobians = tf.zeros(xs[0])

    # corrupt data (Tapani Raiko's dequantization)
    y = x * 0.5 + 0.5
    y = y * 255.0
    corruption_level = 1.0
    y = y + corruption_level * tf.random_uniform(xs)
    y = y / (255.0 + corruption_level)

    # model logit instead of the x itself
    alpha = 1e-5
    y = y * (1 - alpha) + alpha * 0.5
    jac = tf.reduce_sum(-tf.log(y) - tf.log(1 - y), [1, 2, 3])
    y = tf.log(y) - tf.log(1 - y)
    sum_log_det_jacobians += jac

    if len(layers) == 0:
        construct_model_spec()

    # construct forward pass
    z = None
    jac = sum_log_det_jacobians

    for layer in layers:
        y, jac, z = layer.forward_and_jacobian(y, jac, z)

    z = tf.concat([z, y], 3)

    # record dimension of the final variable
    global final_latent_dimension
    final_latent_dimension = real_nvp_layers.int_shape(z)

    return z, jac

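# Hedged sketch (not part of the listing above): inverting the logit
# preprocessing so that values in logit space can be mapped back to image
# space, e.g. when visualizing samples. The dequantization noise is stochastic
# and is simply dropped; `alpha` matches the constant used in model_spec,
# the function name and everything else is an assumption for illustration.
def invert_preprocessing(y_logit, alpha=1e-5):
    s = tf.sigmoid(y_logit)                  # undo y = log(s) - log(1 - s)
    u = (s - alpha * 0.5) / (1.0 - alpha)    # undo y = y*(1-alpha) + alpha*0.5
    # u is approximately the [0, 1] pixel intensity (up to the dropped noise
    # and the 255/256 rescaling); map it back to the [-1, 1] range that
    # model_spec expects as input
    return u * 2.0 - 1.0
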
def forward_and_jacobian(self, x, sum_log_det_jacobians, z):
    xs = utils.int_shape(x)
    assert xs[1] % 2 == 0 and xs[2] % 2 == 0

    # squeeze: trade spatial resolution for channels (H x W x C -> H/2 x W/2 x 4C)
    y = tf.space_to_depth(x, 2)
    if z is not None:
        z = tf.space_to_depth(z, 2)

    # the squeeze is only a permutation of dimensions, so the log-det term
    # is unchanged
    return y, sum_log_det_jacobians, z

def backward(self, y, z):
    ys = utils.int_shape(y)
    assert ys[3] % 4 == 0

    # unsqueeze: invert space_to_depth (H x W x 4C -> 2H x 2W x C)
    x = tf.depth_to_space(y, 2)
    if z is not None:
        z = tf.depth_to_space(z, 2)

    return x, z

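# Rough usage sketch (assumed shapes, hypothetical constructor name): the
# squeeze trades spatial size for channels and backward undoes it exactly.
#   x: [N, 32, 32, 3]  --forward-->  y: [N, 16, 16, 12]  --backward-->  [N, 32, 32, 3]
# e.g.
#   layer = SqueezingLayer()                                   # hypothetical
#   y, jac, z = layer.forward_and_jacobian(x, tf.zeros([N]), None)
#   x_rec, _ = layer.backward(y, z)                            # x_rec == x
# Since the squeeze only permutes entries, sum_log_det_jacobians passes
# through unchanged in both directions.
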
def backward(self, y, z):
    with tf.variable_scope(self.name, reuse=True):
        ys = utils.int_shape(y)
        b = self.get_mask(ys, self.mask_type)

        # the masked half passes through the forward pass unchanged, so l and m
        # can be recomputed from it and the affine transform inverted exactly
        y1 = y * b
        l, m = self.function_l_m(y1, b)
        x = y1 + tf.multiply(y * (-b + 1.0) - m, tf.exp(-l))
        return x, z

def function_l_m(self, x, mask, name='function_l_m'):
    with tf.variable_scope(name):
        channel = 64
        padding = 'SAME'
        xs = utils.int_shape(x)
        kernel_h = 3
        kernel_w = 3
        input_channel = xs[3]
        y = x

        y, _ = utils.batch_norm(y)
        weights_shape = [1, 1, input_channel, channel]
        weights = self.get_normalized_weights("weights_input", weights_shape)
        y = tf.nn.conv2d(y, weights, [1, 1, 1, 1], padding=padding)
        y, _ = utils.batch_norm(y)
        y = tf.nn.relu(y)
        skip = y

        # Residual blocks
        num_residual_blocks = 8
        for r in range(num_residual_blocks):
            weights_shape = [kernel_h, kernel_w, channel, channel]
            weights = self.get_normalized_weights("weights%d_1" % r, weights_shape)
            y = tf.nn.conv2d(y, weights, [1, 1, 1, 1], padding=padding)
            y, _ = utils.batch_norm(y)
            y = tf.nn.relu(y)

            weights_shape = [kernel_h, kernel_w, channel, channel]
            weights = self.get_normalized_weights("weights%d_2" % r, weights_shape)
            y = tf.nn.conv2d(y, weights, [1, 1, 1, 1], padding=padding)
            y, _ = utils.batch_norm(y)
            y += skip
            y = tf.nn.relu(y)
            skip = y

        # 1x1 convolution for reducing dimension
        weights = self.get_normalized_weights(
            "weights_output", [1, 1, channel, input_channel * 2])
        y = tf.nn.conv2d(y, weights, [1, 1, 1, 1], padding=padding)

        # For numerical stability, apply tanh and then scale
        y = tf.tanh(y)
        scale_factor = self.get_normalized_weights("weights_tanh_scale", [1])
        y *= scale_factor

        # The first half defines the l function
        # The second half defines the m function
        l = y[:, :, :, :input_channel] * (-mask + 1)
        m = y[:, :, :, input_channel:] * (-mask + 1)

        return l, m

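# function_l_m relies on a get_normalized_weights helper that is not shown in
# this listing. A plausible weight-normalization-style sketch, offered only as
# an assumption (not necessarily how the repo defines it):
def get_normalized_weights(self, name, weights_shape):
    weights = tf.get_variable(name, weights_shape, tf.float32,
                              tf.contrib.layers.xavier_initializer())
    scale = tf.get_variable(name + "_scale", [1], tf.float32,
                            initializer=tf.constant_initializer(1.0))
    # divide out the weight norm and let a learned scalar set the magnitude
    norm = tf.sqrt(tf.reduce_sum(tf.square(weights)))
    return weights / norm * scale
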
def backward(self, y, z):
    # At scale 0, 1/2 of the original dimensions are factored out
    # At scale 1, 1/4 of the original dimensions are factored out
    # ...
    # At scale s, (1/2)^(s+1) are factored out
    # Hence, at the backward pass of scale s, (1/2)^s of z should be factored in
    zs = utils.int_shape(z)
    if y is None:
        split = zs[3] // (2 ** self.scale)
    else:
        split = utils.int_shape(y)[3]
    new_y = z[:, :, :, -split:]
    z = z[:, :, :, :-split]

    assert utils.int_shape(new_y)[3] == split

    if y is not None:
        x = tf.concat([new_y, y], 3)
    else:
        x = new_y

    return x, z

def forward_and_jacobian(self, x, sum_log_det_jacobians, z):
    with tf.variable_scope(self.name):
        xs = utils.int_shape(x)
        b = self.get_mask(xs, self.mask_type)

        # masked half of x
        x1 = x * b
        l, m = self.function_l_m(x1, b)
        y = x1 + tf.multiply(
            -b + 1.0, x * tf.check_numerics(tf.exp(l), "exp has NaN") + m)
        log_det_jacobian = tf.reduce_sum(l, [1, 2, 3])
        sum_log_det_jacobians += log_det_jacobian

        return y, sum_log_det_jacobians, z

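# Hedged round-trip sketch (hypothetical constructor and argument names): the
# coupling layer is invertible because l and m depend only on the masked half,
# which the forward pass leaves untouched. A quick numerical check under those
# assumptions:
#   layer = CouplingLayer('checkerboard0', name='coupling_1')   # hypothetical
#   y, jac, _ = layer.forward_and_jacobian(x, tf.zeros([N]), None)
#   x_rec, _ = layer.backward(y, None)
#   err = tf.reduce_max(tf.abs(x - x_rec))   # ~0 up to float error
# The log-det term is reduce_sum(l) because the Jacobian is triangular: the
# masked dimensions have derivative 1 and the unmasked ones have exp(l) on the
# diagonal, while l is already zeroed on masked positions inside function_l_m.
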
def forward_and_jacobian(self, x, sum_log_det_jacobians, z):
    xs = utils.int_shape(x)
    split = xs[3] // 2

    # The factoring out is done on the channel direction.
    # Haven't experimented with other ways of factoring out.
    new_z = x[:, :, :, :split]
    x = x[:, :, :, split:]

    if z is not None:
        z = tf.concat([z, new_z], 3)
    else:
        z = new_z

    return x, sum_log_det_jacobians, z

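# Shape bookkeeping sketch (assumed shapes, hypothetical constructor): after
# the first squeeze x has shape [N, 16, 16, 12]; factoring out moves half of
# the channels into z while the rest continue through later layers.
#   layer = FactorOutLayer(scale=0, name='factor_out_1')        # hypothetical
#   x2, jac, z = layer.forward_and_jacobian(x, jac, None)
#   # x2: [N, 16, 16, 6], z: [N, 16, 16, 6]
#   x_rec, z_rest = layer.backward(x2, z)   # x_rec == x, z_rest has 0 channels
# The split/concat has Jacobian determinant 1, so sum_log_det_jacobians is
# passed through unchanged.
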
def get_mask(self, xs, mask_type):
    if 'checkerboard' in mask_type:
        unit0 = tf.constant([[0.0, 1.0], [1.0, 0.0]])
        unit1 = -unit0 + 1.0
        unit = unit0 if mask_type == 'checkerboard0' else unit1
        unit = tf.reshape(unit, [1, 2, 2, 1])
        b = tf.tile(unit, [xs[0], xs[1] // 2, xs[2] // 2, xs[3]])
    elif 'channel' in mask_type:
        white = tf.ones([xs[0], xs[1], xs[2], xs[3] // 2])
        black = tf.zeros([xs[0], xs[1], xs[2], xs[3] // 2])
        if mask_type == 'channel0':
            b = tf.concat([white, black], 3)
        else:
            b = tf.concat([black, white], 3)

    bs = utils.int_shape(b)
    assert bs == xs

    return b

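# Illustrative values (not from the listing above): for xs = [1, 4, 4, 1] the
# 'checkerboard0' mask tiles the 2x2 unit [[0, 1], [1, 0]] over the image:
#   0 1 0 1
#   1 0 1 0
#   0 1 0 1
#   1 0 1 0
# 'checkerboard1' is its complement, so checkerboard0 + checkerboard1 == 1
# everywhere; alternating mask types (and likewise channel0/channel1) between
# coupling layers ensures every dimension gets transformed by some layer.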