def build_model(x, y_label, init=False, sampling_mode=False):
    """
    Args:
        x: shape=(config.batch_size,) + config.obs_shape
           -> (batch_size, seq_len, 32, 32, channels=1)
        y_label: shape=(bs, seq_len, 2)

    This function is called with
    > model = tf.make_template('model', config.build_model)
    and creates all trainable variables.

    If sampling_mode:
        return x_samples
    else:
        return log_probs, log_probs, log_probs
    """
    global nvp_layers
    global nvp_dense_layers

    # Ensures that all nn_extra_nvp.*_wn layers have init=init
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        global gp_layer
        if gp_layer is None:
            gp_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # x: (batch_size, seq_len, 32, 32, channels=1)
        x_shape = nn_extra_nvp.int_shape(x)
        # Reshape into (batch_size * seq_len, 32, 32, channels=1)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        y_label_shape = nn_extra_nvp.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])  # (batch_size * seq_len,)

        # Preprocess (scale and transform) values; shape doesn't change.
        y, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)
        y, log_det_jac = nn_extra_nvp.logit_forward_and_jacobian(
            y, log_det_jac)

        # TODO: Replace RealNVP layers with GLOW layers.
        # Construct the forward pass through the convolutional NVP layers.
        z = None
        for layer in nvp_layers:
            y, log_det_jac, z = layer.forward_and_jacobian(
                y, log_det_jac, z, y_label=y_label_bs)
        # y, log_det_jac, z: (64, 8, 8, 2), (64,), (64, 8, 8, 14)

        z = tf.concat([z, y], 3)  # Join the split channels back
        # z: (64, 8, 8, 16)

        # Followed by 6 dense layers of 256 units with alternating partitions/masks.
        for layer in nvp_dense_layers:
            z, log_det_jac, _ = layer.forward_and_jacobian(
                z, log_det_jac, None, y_label=y_label_bs)
        # log_det_jac, z: (64,), (64, 8, 8, 16)

        z_shape = nn_extra_nvp.int_shape(z)

        # Reshape z to (batch_size, seq_len, -1);
        # the last dimension is the number of dimensions in the data, H*W*C.
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        # z_vec: (4, 16, 1024), i.e. (bs, seq, H*W*C=ndim)

        # The log det Jacobian dz_i/dx_i for every i in a sequence of length n.
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))
        # log_det_jac: (4, 16), i.e. (bs, seq)

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE) as scope:
            gp_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    # Condition on the context points, then sample the rest.
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    # Sampling mode from just the prior (no context)
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        # Condition the predictive distribution on the i-th latent
                        gp_layer.update_distribution(z_vec[:, i, :])
            else:
                # Training mode
                if n_context > 0:
                    # Part of the sequence consists of context points
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(n_context, seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    # No context: score each step under the sequentially updated predictive
                    for i in range(seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        gp_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(z_samples, z_shape)  # (n_samples*seq_len, z_img_shape)

            for layer in reversed(nvp_dense_layers):
                z_samples, _ = layer.backward(z_samples, None, y_label=y_label_bs)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples = layer.backward(
                    x_samples, z_samples, y_label=y_label_bs)

            # Inverse logit
            x_samples = 1. / (1 + tf.exp(-x_samples))
            x_samples = tf.reshape(
                x_samples,
                (z_samples_shape[0], z_samples_shape[1],
                 x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        # Stack the per-step log probs: a list of seq_len tensors of shape (bs,)
        # becomes a single (bs, seq_len) tensor.
        log_probs = tf.stack(log_probs, axis=1)

        return log_probs, log_probs, log_probs
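
# --- Hedged usage sketch (not part of the original code) ------------------------
# The docstring above says build_model is wrapped with tf.make_template so that
# repeated calls share one set of trainable variables. A minimal sketch of that
# pattern follows; `config` stands for this module, and the placeholder names and
# exact config attributes (batch_size, obs_shape, seq_len) are assumptions.
def _example_build_model_usage(config):
    import tensorflow as tf

    # Every call to `model` below reuses the variables created on the first call.
    model = tf.make_template('model', config.build_model)

    x_in = tf.placeholder(tf.float32, shape=(config.batch_size,) + config.obs_shape)
    y_in = tf.placeholder(tf.float32, shape=(config.batch_size, config.seq_len, 2))

    # A first call with init=True runs the data-dependent weight-norm initialisation.
    init_out = model(x_in, y_in, init=True)
    # Training graph: per-step log probabilities (latent log-likelihood + log|det Jacobian|).
    log_probs, _, _ = model(x_in, y_in)
    # Sampling graph: images decoded from latents drawn by the recurrent Gaussian layer.
    x_samples = model(x_in, y_in, sampling_mode=True)
    return init_out, log_probs, x_samples
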
def build_model(x, init=False, sampling_mode=False):
    global nvp_layers
    global nvp_dense_layers
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        global student_layer
        if student_layer is None:
            student_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        x_shape = nn_extra_nvp.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        log_det_jac = tf.zeros(x_bs_shape[0])

        logit_layer = nn_extra_nvp.LogitLayer()
        scale_layer = nn_extra_nvp.ScaleLayer()

        y, log_det_jac = scale_layer.forward_and_jacobian(
            x_bs, None, log_det_jac)
        y, log_det_jac = logit_layer.forward_and_jacobian(y, None, log_det_jac)

        # construct forward pass
        z = None
        for layer in nvp_layers:
            y, z, log_det_jac = layer.forward_and_jacobian(y, z, log_det_jac)

        z = tf.concat([z, y], 3)
        for layer in nvp_dense_layers:
            z, _, log_det_jac = layer.forward_and_jacobian(
                z, None, log_det_jac)

        z_shape = nn_extra_nvp.int_shape(z)
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []
        latent_log_probs = []
        latent_log_probs_prior = []

        with tf.variable_scope("one_step") as scope:
            student_layer.reset()
            for i in range(seq_len):
                if sampling_mode:
                    z_sample = student_layer.sample(nr_samples=n_samples)
                    z_samples.append(z_sample)

                    latent_log_prob = student_layer.get_log_likelihood(
                        z_sample[:, 0, :])
                    latent_log_probs.append(latent_log_prob)
                else:
                    latent_log_prob = student_layer.get_log_likelihood(
                        z_vec[:, i, :])
                    latent_log_probs.append(latent_log_prob)

                    log_prob = latent_log_prob + log_det_jac[:, i]
                    log_probs.append(log_prob)

                    latent_log_prob_prior = student_layer.get_log_likelihood_under_prior(
                        z_vec[:, i, :])
                    latent_log_probs_prior.append(latent_log_prob_prior)

                student_layer.update_distribution(z_vec[:, i, :])
                scope.reuse_variables()

        if sampling_mode:
            # one more sample after seeing the last element in the sequence
            z_sample = student_layer.sample(nr_samples=n_samples)
            z_samples.append(z_sample)
            z_samples = tf.concat(z_samples, 1)

            latent_log_prob = student_layer.get_log_likelihood(z_sample[:, 0, :])
            latent_log_probs.append(latent_log_prob)

            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(
                z_samples,
                (z_samples_shape[0] * z_samples_shape[1],
                 z_shape[1], z_shape[2], z_shape[3]))  # (n_samples*seq_len, z_img_shape)

            log_det_jac = tf.zeros(z_samples_shape[0] * z_samples_shape[1])
            for layer in reversed(nvp_dense_layers):
                z_samples, _, log_det_jac = layer.backward(
                    z_samples, None, log_det_jac)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples, log_det_jac = layer.backward(
                    x_samples, z_samples, log_det_jac)

            x_samples, log_det_jac = logit_layer.backward(
                x_samples, None, log_det_jac)
            x_samples, log_det_jac = scale_layer.backward(
                x_samples, None, log_det_jac)

            x_samples = tf.reshape(
                x_samples,
                (z_samples_shape[0], z_samples_shape[1],
                 x_shape[2], x_shape[3], x_shape[4]))

            log_det_jac = tf.reshape(
                log_det_jac, (z_samples_shape[0], z_samples_shape[1]))
            latent_log_probs = tf.stack(latent_log_probs, axis=1)

            for i in range(seq_len + 1):
                log_prob = latent_log_probs[:, i] - log_det_jac[:, i]
                log_probs.append(log_prob)

            log_probs = tf.stack(log_probs, axis=1)

            return x_samples, log_probs

        log_probs = tf.stack(log_probs, axis=1)
        latent_log_probs = tf.stack(latent_log_probs, axis=1)
        latent_log_probs_prior = tf.stack(latent_log_probs_prior, axis=1)

        return log_probs, latent_log_probs, latent_log_probs_prior, z_vec
def build_model(x, y_label, init=False, sampling_mode=False,
                glow_width=8, glow_num_steps=32, glow_num_scales=4):
    """
    Args:
        x: float32 placeholder with shape=(config.batch_size,) + config.obs_shape
           -> (batch_size, seq_len, 32, 32, channels=1)
        y_label: shape=(bs, seq_len, 2)

    This function is called with
    > model = tf.make_template('model', config.build_model)
    and creates all trainable variables.

    If sampling_mode:
        return x_samples
    else:
        return log_probs, log_probs, log_probs
    """
    # Ensures that all nn_extra_nvp.*_wn layers have init=init
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        if glow_model is None:
            build_glow_model(glow_width, glow_num_steps, glow_num_scales)

        global gp_layer
        if gp_layer is None:
            gp_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # x: (batch_size, seq_len, 32, 32, channels=1)
        x_shape = K.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = K.int_shape(x_bs)

        y_label_shape = K.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))

        # Extract features of the conditioning h (y labels).
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])

        # GLOW itself does no jitter pre-transformation, but the inputs may still
        # need to be scaled, so dequantization is applied here.
        x_bs, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)

        # Construct the forward pass through the GLOW layers. All
        # batch_size * seq_len images in x_bs go through the same flow
        # independently; the sequence structure only enters through the
        # recurrent latent layer below.
        z = None
        input_flow = fl.InputLayer(x_bs, y_label_bs)

        # Forward flow
        output_flow = glow_model(input_flow, forward=True)
        x, log_det_jac, z, y_label = output_flow
        # x: (64, 2, 2, 16), z: (64, 2, 2, 240), log_det_jac: (64,)

        z = tf.concat([z, x], 3)  # Join the split channels back
        # z: (64, 2, 2, 256)
        z_shape = K.int_shape(z)

        # Reshape z to (batch_size, seq_len, -1);
        # the last dimension is the number of dimensions in the data, H*W*C.
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        # z_vec: (4, 16, 1024)

        # The log det Jacobian dz_i/dx_i for every i in a sequence of length n.
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE) as scope:
            gp_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    # Condition on the context points, then sample the rest.
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    # Sampling mode from just the prior (no context)
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        # Condition the predictive distribution on the i-th latent  # (64, 1, 480)
                        gp_layer.update_distribution(z_vec[:, i, :])
            else:
                # Training mode
                if n_context > 0:
                    # Part of the sequence consists of context points
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(n_context, seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    # No context: score each step under the sequentially updated predictive
                    for i in range(seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        gp_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = K.int_shape(z_samples)
            z_samples = tf.reshape(z_samples, z_shape)  # (n_samples*seq_len, z_img_shape)

            x, log_det_jac, z, y_label = output_flow
            output_x_shape = K.int_shape(x)
            split = output_x_shape[3]  # Channels in the output flow's x portion

            inverse_y, inverse_z = z_samples[:, :, :, :split], z_samples[:, :, :, split:]
            log_det_tmp = tf.zeros_like(log_det_jac)
            inverse_flow = inverse_y, log_det_tmp, inverse_z, y_label_bs

            print(f"Reconstructing model with inverse flow {inverse_flow}")
            reconstruction = glow_model(inverse_flow, forward=False)
            x_samples = reconstruction[0]

            x_samples = tf.reshape(
                x_samples,
                (z_samples_shape[0], z_samples_shape[1],
                 x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        # Stack the per-step log probs: a list of (bs,) tensors -> (bs, seq);
        # log_probs: (4, 15)
        log_probs = tf.stack(log_probs, axis=1)

        return log_probs, log_probs, log_probs
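
# --- Hedged illustration (not from the original code) ---------------------------
# The training branch above implements the factorised sequence likelihood
#     log p(x_1..T) = sum_t [ log p(z_t | z_{<t}) + log|det dz_t/dx_t| ],
# where p(z_t | z_{<t}) comes from gp_layer and the Jacobian term from the flow.
# A toy numpy version, with `log_lik` / `update` standing in for the layer's
# get_log_likelihood / update_distribution interface (names assumed for illustration):
import numpy as np


def _toy_sequence_log_prob(z_seq, log_det_seq, log_lik, update):
    """z_seq: (batch, seq_len, ndim); log_det_seq: (batch, seq_len)."""
    total = np.zeros(z_seq.shape[0])
    for t in range(z_seq.shape[1]):
        # Likelihood of the t-th latent under the current predictive, plus the
        # change-of-variables correction for mapping x_t -> z_t.
        total += log_lik(z_seq[:, t, :]) + log_det_seq[:, t]
        # Condition the predictive distribution on z_t before the next step.
        update(z_seq[:, t, :])
    return total
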
x1 = rng.multivariate_normal(phi, K)
x1 = x1.reshape((seq_len, p))
x1 = np.float32(x1)
xs.append(x1[None, :, :])

x = np.concatenate(xs, axis=0)
print('shape x', x.shape)

x_var_tf = tf.placeholder(tf.float32, shape=(batch_size, seq_len, p))

l_rnn = nn_extra_student.StudentRecurrentLayer(shape=(p, ),
                                               nu_init=g_nu,
                                               mu_init=g_mu,
                                               var_init=g_var,
                                               corr_init=g_corr)
l_rnn2 = nn_extra_gauss.GaussianRecurrentLayer(shape=(p, ),
                                               mu_init=g_mu,
                                               var_init=g_var,
                                               corr_init=g_corr)

probs = []
probs_gauss = []
with tf.variable_scope("one_step") as scope:
    l_rnn.reset()
    l_rnn2.reset()
    for i in range(seq_len):
        prob_i = l_rnn.get_log_likelihood(x_var_tf[:, i, :])
        probs.append(prob_i)
        l_rnn.update_distribution(x_var_tf[:, i, :])

        prob_i = l_rnn2.get_log_likelihood(x_var_tf[:, i, :])
        probs_gauss.append(prob_i)
        l_rnn2.update_distribution(x_var_tf[:, i, :])
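
# --- Hedged evaluation sketch (not part of the original script) -----------------
# How the two recurrent layers built above could be run and compared on the
# generated sequences; the session setup and the mean reduction are assumptions.
probs_student_tf = tf.stack(probs, axis=1)      # (batch_size, seq_len) Student-t log-likelihoods
probs_gauss_tf = tf.stack(probs_gauss, axis=1)  # (batch_size, seq_len) Gaussian log-likelihoods

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    p_student, p_gauss = sess.run([probs_student_tf, probs_gauss_tf],
                                  feed_dict={x_var_tf: x})
    print('mean log-likelihood (student):', p_student.mean())
    print('mean log-likelihood (gauss):', p_gauss.mean())
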
def build_model(x, y_label, init=False, sampling_mode=False):
    global nvp_layers
    global nvp_dense_layers
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        global student_layer
        if student_layer is None:
            student_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        x_shape = nn_extra_nvp.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        y_label_shape = nn_extra_nvp.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])

        y, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)
        y, log_det_jac = nn_extra_nvp.logit_forward_and_jacobian(
            y, log_det_jac)

        # construct forward pass
        z = None
        for layer in nvp_layers:
            y, log_det_jac, z = layer.forward_and_jacobian(
                y, log_det_jac, z, y_label=y_label_bs)

        z = tf.concat([z, y], 3)

        for layer in nvp_dense_layers:
            z, log_det_jac, _ = layer.forward_and_jacobian(
                z, log_det_jac, None, y_label=y_label_bs)

        z_shape = nn_extra_nvp.int_shape(z)
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE) as scope:
            student_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    for i in range(n_context):
                        student_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = student_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    for i in range(seq_len):
                        z_sample = student_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        student_layer.update_distribution(z_vec[:, i, :])
            else:
                if n_context > 0:
                    for i in range(n_context):
                        student_layer.update_distribution(z_vec[:, i, :])
                    for i in range(n_context, seq_len):
                        latent_log_prob = student_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    for i in range(seq_len):
                        latent_log_prob = student_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        student_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(z_samples, z_shape)  # (n_samples*seq_len, z_img_shape)

            for layer in reversed(nvp_dense_layers):
                z_samples, _ = layer.backward(z_samples, None, y_label=y_label_bs)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples = layer.backward(
                    x_samples, z_samples, y_label=y_label_bs)

            # inverse logit
            x_samples = 1. / (1 + tf.exp(-x_samples))
            x_samples = tf.reshape(
                x_samples,
                (z_samples_shape[0], z_samples_shape[1],
                 x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        log_probs = tf.stack(log_probs, axis=1)
        return log_probs, log_probs, log_probs