# Example #1 (score: 0)
def build_model(x, init=False, sampling_mode=False):
    """Build the NVP-flow + Student-t recurrent model graph for a sequence batch.

    On first use this lazily constructs the module-level flow layers
    (``nvp_layers`` via ``build_nvp_model``, ``nvp_dense_layers`` via
    ``build_nvp_dense_model``) and the module-level ``student_layer``; later
    calls reuse the already-built objects.

    Every frame of ``x`` is pushed through a scale layer, a logit layer, the
    NVP layers and the dense NVP layers to obtain a latent vector, while the
    log-det-Jacobian of the transform is accumulated. The latent sequence is
    then processed step by step with ``student_layer``, which is conditioned
    on each observed latent via ``update_distribution``.

    Args:
        x: 5-D tensor. Its first two axes are merged before the flow and
            split again afterwards. Presumably
            (batch, seq_len, height, width, channels) -- TODO confirm.
        init: forwarded as ``init=`` to every ``nn_extra_nvp.conv2d_wn`` /
            ``nn_extra_nvp.dense_wn`` call made inside this function (via
            ``arg_scope``). Presumably selects data-dependent weight-norm
            initialisation -- confirm against nn_extra_nvp.
        sampling_mode: if False, score ``x``. If True, draw samples from
            ``student_layer`` (still conditioned on the latents of ``x``) and
            invert the flow to produce samples in input space.

    Returns:
        If ``sampling_mode`` is False, the tuple
        ``(log_probs, latent_log_probs, latent_log_probs_prior, z_vec)``:

        - the first three are the per-step values stacked along axis 1;
        - ``z_vec`` is the latent sequence reshaped to
          ``(x_shape[0], x_shape[1], -1)``.

        If ``sampling_mode`` is True, the tuple ``(x_samples, log_probs)``:

        - it covers ``seq_len + 1`` steps along axis 1;
        - those steps are one sample drawn before each observed step, plus
          one drawn after the last step.

    Relies on module-level names defined elsewhere: ``ndim``, ``corr_init``,
    ``nu_init``, ``seq_len``, ``n_samples``, ``build_nvp_model`` and
    ``build_nvp_dense_model``.
    """
    global nvp_layers
    global nvp_dense_layers
    # Every conv2d_wn / dense_wn created inside this block receives init=init.
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        # Build the shared flow layers only once; later calls reuse them.
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        # Legal here because student_layer is not referenced earlier in the
        # function; the layer is created once and cached at module level.
        global student_layer
        if student_layer is None:
            student_layer = nn_extra_student.StudentRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init, nu_init=nu_init)

        # Merge the first two axes so each frame goes through the flow
        # independently.
        x_shape = nn_extra_nvp.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        # Log-det-Jacobian accumulator: one entry per merged frame, threaded
        # through every layer below.
        log_det_jac = tf.zeros(x_bs_shape[0])

        # Fresh instances on every call; they are reused for the inverse pass
        # in sampling mode below.
        logit_layer = nn_extra_nvp.LogitLayer()
        scale_layer = nn_extra_nvp.ScaleLayer()

        # Pre-processing: scale first, then logit (inverted in reverse order
        # when sampling).
        y, log_det_jac = scale_layer.forward_and_jacobian(
            x_bs, None, log_det_jac)
        y, log_det_jac = logit_layer.forward_and_jacobian(y, None, log_det_jac)

        # construct forward pass
        # Each NVP layer returns (y, z, log_det_jac); z starts as None and is
        # threaded back in, so layers can presumably factor out part of the
        # tensor into z -- confirm against nn_extra_nvp.
        z = None
        for layer in nvp_layers:
            y, z, log_det_jac = layer.forward_and_jacobian(y, z, log_det_jac)

        # Join the factored-out part with the remaining tensor on axis 3
        # (presumably the channel axis), then apply the dense flow layers.
        z = tf.concat([z, y], 3)
        for layer in nvp_dense_layers:
            z, _, log_det_jac = layer.forward_and_jacobian(
                z, None, log_det_jac)

        # z_shape is kept for reshaping samples back to latent-image form.
        z_shape = nn_extra_nvp.int_shape(z)
        # Restore the (x_shape[0], x_shape[1]) leading axes and flatten each
        # frame's latent into a vector.
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []
        latent_log_probs = []
        latent_log_probs_prior = []

        with tf.variable_scope("one_step") as scope:
            # Called before stepping through the sequence; presumably clears
            # the recurrent state -- confirm in nn_extra_student.
            student_layer.reset()
            for i in range(seq_len):
                if sampling_mode:
                    # Draw n_samples from the layer's current predictive
                    # distribution, before conditioning on step i.
                    z_sample = student_layer.sample(nr_samples=n_samples)
                    z_samples.append(z_sample)

                    latent_log_prob = student_layer.get_log_likelihood(
                        z_sample[:, 0, :])
                    latent_log_probs.append(latent_log_prob)

                else:
                    latent_log_prob = student_layer.get_log_likelihood(
                        z_vec[:, i, :])
                    latent_log_probs.append(latent_log_prob)

                    # Change of variables: log p(x_i) = log p(z_i) + log-det
                    # of the forward map.
                    log_prob = latent_log_prob + log_det_jac[:, i]
                    log_probs.append(log_prob)

                    latent_log_prob_prior = student_layer.get_log_likelihood_under_prior(
                        z_vec[:, i, :])
                    latent_log_probs_prior.append(latent_log_prob_prior)

                # In both modes the layer is conditioned on the OBSERVED
                # latent of step i.
                student_layer.update_distribution(z_vec[:, i, :])
                # Variables created on the first step are reused on later
                # steps.
                scope.reuse_variables()

        if sampling_mode:
            # one more sample after seeing the last element in the sequence
            # NOTE(review): this call is outside the "one_step" variable scope;
            # confirm sample() creates no variables.
            z_sample = student_layer.sample(nr_samples=n_samples)
            z_samples.append(z_sample)
            z_samples = tf.concat(z_samples, 1)

            latent_log_prob = student_layer.get_log_likelihood(z_sample[:,
                                                                        0, :])
            latent_log_probs.append(latent_log_prob)

            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(
                z_samples,
                (z_samples_shape[0] * z_samples_shape[1], z_shape[1],
                 z_shape[2], z_shape[3]))  # (z_samples_shape[0] * steps, *z_img_shape); steps is seq_len + 1 if each sample adds one step on axis 1

            # Fresh accumulator for the inverse pass.
            log_det_jac = tf.zeros(z_samples_shape[0] * z_samples_shape[1])
            # Invert the flow in reverse layer order: dense NVP, then NVP,
            # then logit, then scale.
            for layer in reversed(nvp_dense_layers):
                z_samples, _, log_det_jac = layer.backward(
                    z_samples, None, log_det_jac)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples, log_det_jac = layer.backward(
                    x_samples, z_samples, log_det_jac)

            x_samples, log_det_jac = logit_layer.backward(
                x_samples, None, log_det_jac)
            x_samples, log_det_jac = scale_layer.backward(
                x_samples, None, log_det_jac)
            x_samples = tf.reshape(x_samples,
                                   (z_samples_shape[0], z_samples_shape[1],
                                    x_shape[2], x_shape[3], x_shape[4]))

            log_det_jac = tf.reshape(log_det_jac,
                                     (z_samples_shape[0], z_samples_shape[1]))
            latent_log_probs = tf.stack(latent_log_probs, axis=1)

            # seq_len in-loop samples plus the extra one above.
            # NOTE(review): the log-det is subtracted here, which assumes
            # backward() accumulates the log-det of the inverse (z -> x) map
            # -- confirm against nn_extra_nvp.
            for i in range(seq_len + 1):
                log_prob = latent_log_probs[:, i] - log_det_jac[:, i]
                log_probs.append(log_prob)

            log_probs = tf.stack(log_probs, axis=1)

            return x_samples, log_probs

        log_probs = tf.stack(log_probs, axis=1)
        latent_log_probs = tf.stack(latent_log_probs, axis=1)
        latent_log_probs_prior = tf.stack(latent_log_probs_prior, axis=1)

        return log_probs, latent_log_probs, latent_log_probs_prior, z_vec
# Example #2 (score: 0)
# Draw `batch_size` sequences from a multivariate normal.
# - The mean repeats `x_mu` at every time step.
# - The covariance comes from `create_covariance_matrix`; presumably it is a
#   (seq_len*p, seq_len*p) matrix coupling all steps and dims (confirm).
# - Each draw is reshaped to (seq_len, p), cast to float32, and the draws are
#   stacked into x with shape (batch_size, seq_len, p).
K = create_covariance_matrix(seq_len, p, x_cov, x_var)
# The mean vector does not depend on the sample index, so build it once
# instead of re-tiling it on every iteration.
phi = np.tile(x_mu, (seq_len, 1)).flatten()
xs = []
for i in range(batch_size):
    x1 = rng.multivariate_normal(phi, K)
    x1 = x1.reshape((seq_len, p))
    x1 = np.float32(x1)
    # Add a leading batch axis so the draws can be concatenated.
    xs.append(x1[None, :, :])

x = np.concatenate(xs, axis=0)
print('shape x', x.shape)

# Graph input: one float32 batch of shape (batch_size, seq_len, p).
x_var_tf = tf.placeholder(tf.float32, shape=(batch_size, seq_len, p))

# Both recurrent layers share the event shape and the mean / variance /
# correlation initialisers; only the Student-t layer also takes nu_init.
_shared_layer_kwargs = dict(shape=(p, ),
                            mu_init=g_mu,
                            var_init=g_var,
                            corr_init=g_corr)
l_rnn = nn_extra_student.StudentRecurrentLayer(nu_init=g_nu,
                                               **_shared_layer_kwargs)
l_rnn2 = nn_extra_gauss.GaussianRecurrentLayer(**_shared_layer_kwargs)
# Temporary helper only; drop it so the module namespace is unchanged.
del _shared_layer_kwargs

# Per-step log-likelihoods filled by the loop that follows: probs for the
# Student-t layer, probs_gauss presumably for the Gaussian layer.
probs, probs_gauss = [], []
with tf.variable_scope("one_step") as scope:
    l_rnn.reset()
    l_rnn2.reset()
    for i in range(seq_len):
        prob_i = l_rnn.get_log_likelihood(x_var_tf[:, i, :])
        probs.append(prob_i)