def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
    super(PixelCNNLayer_down, self).__init__()
    self.nr_resnet = nr_resnet
    # stream from pixels above
    self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
                                                resnet_nonlinearity, skip_connection=1)
                                   for _ in range(nr_resnet)])
    # stream from pixels above and to the left
    self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
                                                 resnet_nonlinearity, skip_connection=2)
                                    for _ in range(nr_resnet)])
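# Sketch of a matching forward pass (an assumption; not part of the original
# snippet): each u_stream block takes one popped encoder activation as its
# skip input (skip_connection=1), while each ul_stream block takes the
# updated u concatenated with a popped encoder activation (skip_connection=2).
# Assumes gated_resnet.forward accepts an auxiliary tensor via `a=`.
def forward(self, u, ul, u_list, ul_list):
    for i in range(self.nr_resnet):
        u = self.u_stream[i](u, a=u_list.pop())
        ul = self.ul_stream[i](ul, a=torch.cat((u, ul_list.pop()), 1))
    return u, ul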
def cond_pixel_cnn(x, gh=None, sh=None, nonlinearity=tf.nn.elu, nr_resnet=5,
                   nr_filters=160, nr_logistic_mix=10, bn=True, dropout_p=0.0,
                   kernel_initializer=None, kernel_regularizer=None,
                   is_training=False, counters={}):
    name = nn.get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    with tf.variable_scope(name):
        # do not use batch normalization for the auto-regressive model; force bn to be False
        with arg_scope([gated_resnet], gh=gh, sh=sh, nonlinearity=nonlinearity,
                       dropout_p=dropout_p, counters=counters):
            with arg_scope([gated_resnet, down_shifted_conv2d,
                            down_right_shifted_conv2d, down_shifted_deconv2d,
                            down_right_shifted_deconv2d],
                           bn=False, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training):
                xs = nn.int_shape(x)
                # add channel of ones to distinguish image from padding later on
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # stream for pixels above
                u_list = [down_shift(down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # stream for pixels above and to the left
                ul_list = [down_shift(down_shifted_conv2d(
                               x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                           right_shift(down_right_shifted_conv2d(
                               x_pad, num_filters=nr_filters, filter_size=[2, 1]))]
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1],
                                                conv=down_right_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
                print(" * receptive_field", receptive_field)
                return x_out
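# Usage sketch (hypothetical shapes; assumes the helper modules used above are
# importable and a TF1-style graph). For 3-channel images, the
# 10 * nr_logistic_mix output channels hold, per mixture component: one
# mixture logit, three means, three log-scales, and three RGB coupling
# coefficients of the discretized logistic mixture.
def build_pixel_cnn_head(images, nr_logistic_mix=10):
    # images: [N, H, W, 3] floats, e.g. scaled to [-1, 1]
    return cond_pixel_cnn(images, nr_resnet=5, nr_filters=160,
                          nr_logistic_mix=nr_logistic_mix,
                          is_training=True, counters={})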
def _inference_bu(self, x, ema):
    kwargs = {"training": self.ph_is_training, "ema": ema, "init": self.init}
    stochastic_inference_bottom_up = []
    deterministic_path_top_down = []
    deterministic_path_bottom_up = []
    skip = None
    d_bu = x
    d_td = x
    dim_length = len(self.deterministic_layers[0])
    for i, dims in enumerate(self.deterministic_layers):
        assert len(dims) == dim_length
        # Build deterministic block
        for j, dim in enumerate(dims):
            scope = "deterministic_bottom_up_%i_%i" % (i, j)
            residual = False if i == 0 and j == 0 else True
            if i > 0:
                skip = tf.concat([
                    deterministic_path_bottom_up[i - 1],
                    deterministic_path_top_down[i - 1]
                ], axis=-1)
            d_bu = gated_resnet(d_bu, skip, dim, self.activation, scope,
                                residual, self.dropout_inference, **kwargs)
        deterministic_path_bottom_up += [d_bu]
        for j, dim in enumerate(dims):
            scope = "deterministic_top_down_%i_%i" % (i, j)
            residual = False if i == 0 and j == 0 else True
            d_td = gated_resnet(d_td, deterministic_path_bottom_up[i], dim,
                                self.activation, scope, residual,
                                self.dropout_inference, **kwargs)
        deterministic_path_top_down += [d_td]
        if i == len(self.deterministic_layers) - 1:
            # Do not add the top z layer before building the top-down inference model.
            break
        # Build stochastic layer
        dim = self.stochastic_layers[i]
        scope = "qz_bottom_up_%i" % (i + 1)
        q_z_bottom_up, q_mean_bottom_up, q_var_bottom_up = self._stochastic(
            d_bu, dim, scope, ema)
        stochastic_inference_bottom_up += [
            (q_z_bottom_up, q_mean_bottom_up, q_var_bottom_up)
        ]
        if len(q_z_bottom_up.get_shape()) == 2:
            flatten_shape = [int(dim) for dim in d_bu.get_shape()[1:]]
            scope = "dense2conv_bottom_up_%i" % (i + 1)
            d_bu = dense(q_z_bottom_up, np.prod(flatten_shape), scope, **kwargs)
            d_bu = tf.reshape(d_bu, [-1] + flatten_shape)
        else:
            d_bu = q_z_bottom_up

    # Build top stochastic layer q(z_L | x) from the input of bottom-up and
    # top-down inference.
    dim = self.deterministic_layers[-1][-1]
    d = tf.concat([d_td, d_bu], axis=-1)
    scope = "deterministic_top"
    d = gated_resnet(d, None, [dim[0], dim[1], [1, 1]], self.activation,
                     scope, True, self.dropout_inference, **kwargs)
    flatten_shapes = [[int(dim) for dim in d.get_shape()[1:]]]
    scope = "qz_top_%i" % len(self.stochastic_layers)
    q_z_top, q_mean_top, q_var_top = self._stochastic(
        d, self.stochastic_layers[-1], scope, ema)
    stochastic_inference_bottom_up += [(q_z_top, q_mean_top, q_var_top)]
    return stochastic_inference_bottom_up, deterministic_path_top_down, flatten_shapes
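# Sketch of the reparameterized diagonal Gaussian that self._stochastic is
# assumed to return as (sample, mean, variance); the real method may instead
# use convolutional projections or EMA-averaged weights. Hypothetical helper,
# covering the dense-latent case only.
def _stochastic_sketch(h, n_latent):
    mean = tf.layers.dense(h, n_latent, name="mean")
    var = tf.nn.softplus(tf.layers.dense(h, n_latent, name="var")) + 1e-5
    eps = tf.random_normal(tf.shape(mean))
    return mean + tf.sqrt(var) * eps, mean, var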
def _inference_td(self, stochastic_inference_bottom_up,
                  deterministic_path_top_down, flatten_shapes, ema):
    kwargs = {"training": self.ph_is_training, "ema": ema, "init": self.init}
    stochastic_inference_top_down = []
    z_top = stochastic_inference_bottom_up[-1][0]
    stochastic_inference_bottom_up = stochastic_inference_bottom_up[:-1]
    stochastic_layers_reordered = self.stochastic_layers[::-1][1:]
    deterministic_path_top_down_reordered = deterministic_path_top_down[::-1][1:]
    generative_skip_connections = []
    for i, dims in enumerate(self.deterministic_layers[::-1][:-1]):
        z_top_reshaped = z_top
        scope_index = len(self.stochastic_layers) - (i + 1)
        if len(z_top_reshaped.get_shape()) == 2:
            scope = "dense2conv_%i" % scope_index
            z_top_reshaped = dense(z_top_reshaped, np.prod(flatten_shapes[i]),
                                   scope, **kwargs)
            z_top_reshaped = tf.reshape(z_top_reshaped, [-1] + flatten_shapes[i])
        # Build deterministic block of generative model.
        d_p = z_top_reshaped
        for j, dim in enumerate(dims[::-1]):
            skip = None
            if i > 0:
                skip = generative_skip_connections.pop()
            residual = False if j == 0 else True
            scope = "deterministic_generative_%i_%i" % (i, j)
            d_p = transposed_gated_resnet(d_p, skip, dim, self.activation,
                                          scope, residual,
                                          self.dropout_generative, **kwargs)
            generative_skip_connections = [d_p] + generative_skip_connections
        q_z_bottom_up = stochastic_inference_bottom_up[::-1][i][0]
        # Build top-down stochastic layer q(z_(L-(i+1)) | z_(L-i)) and
        # p(z_(L-(i+1)) | z_(L-i))
        d_td = deterministic_path_top_down_reordered[i]
        scope = "qz_top_down_pz_merge_%i" % scope_index
        dim = dims[::-1][-1]
        dim = [dim[0], dim[1], [1, 1]]
        d_td = gated_resnet(d_td, d_p, dim, self.activation, scope, True,
                            self.dropout_generative, **kwargs)
        flatten_shapes += [[int(d) for d in d_td.get_shape()[1:]]]
        scope = "qz_top_down_%i" % scope_index
        q_z_top_down, q_mean_top_down, q_var_top_down = self._stochastic(
            d_td, stochastic_layers_reordered[i], scope, ema)
        stochastic_inference_top_down += [(q_z_top_down, q_mean_top_down,
                                           q_var_top_down)]
        z_top = tf.concat([q_z_top_down, q_z_bottom_up], axis=-1)
    return stochastic_inference_top_down[::-1]
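# Usage sketch (hypothetical; `model` is an instance of the surrounding
# class): the bottom-up pass yields per-layer posteriors plus the tensors the
# top-down pass needs, and the top-down pass returns refined (sample, mean,
# variance) triples for every layer below the top, ordered bottom-to-top.
def run_inference(model, x, ema=None):
    bu, td_path, shapes = model._inference_bu(x, ema)
    td = model._inference_td(bu, td_path, shapes, ema)
    # The top layer's posterior stays at the end of `bu`.
    return bu, td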