Exemplo n.º 1
0
def build_model(x, y_label, init=False, sampling_mode=False):
    """Build the RealNVP flow + Gaussian recurrent prior graph for sequences.

    Args:
        x: tensor of shape (config.batch_size,) + config.obs_shape, i.e.
           (batch_size, seq_len, H, W, C); the example configuration is
           (batch_size, seq_len, 32, 32, 1).
        y_label: conditioning labels of shape (batch_size, seq_len, label_dim),
           e.g. (bs, seq_len, 2). They are flattened over (batch, seq) and
           embedded by a 32-unit dense layer named 'labels_layer'.
        init: forwarded via `arg_scope` to every `nn_extra_nvp.conv2d_wn` /
           `nn_extra_nvp.dense_wn` call.
        sampling_mode: build the sampling graph instead of the
           log-likelihood graph.

    This function is called with
    > model = tf.make_template('model', config.build_model)
    and creates all trainable variables

    Returns:
        If sampling_mode: x_samples reshaped to
            (z_samples_shape[0], z_samples_shape[1], H, W, C).
        Otherwise: (log_probs, log_probs, log_probs) -- the same tensor three
            times, with one column per scored time step
            (steps n_context .. seq_len - 1).

    Raises:
        ValueError: in log-likelihood mode when no time step is scored
            (n_context >= seq_len), since an empty list cannot be stacked.
    """
    global nvp_layers
    global nvp_dense_layers
    # Ensures that all nn_extra_nvp.*_wn layers have init=init
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        # The layer lists are module-level and built lazily on first use
        # (the build_* helpers are expected to populate them).
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        global gp_layer
        if gp_layer is None:
            gp_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # (batch_size, seq_len, H, W, C)
        x_shape = nn_extra_nvp.int_shape(x)
        # Reshape into (batch_size * seq_len, H, W, C): every frame goes
        # through the flow independently.
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        y_label_shape = nn_extra_nvp.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])  # one entry per frame

        # Preprocess (dequantize, then logit); shape doesn't change and each
        # step adds its log-det term.
        y, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)
        y, log_det_jac = nn_extra_nvp.logit_forward_and_jacobian(
            y, log_det_jac)

        # TODO: Replace RealNVP layers with GLOW layers.
        # Forward pass through the convolutional NVP layers; `z` accumulates
        # the factored-out channels.
        z = None
        for layer in nvp_layers:
            y, log_det_jac, z = layer.forward_and_jacobian(y,
                                                           log_det_jac,
                                                           z,
                                                           y_label=y_label_bs)
        # e.g. y, log_det_jac, z: (64, 8, 8, 2), (64,), (64, 8, 8, 14)
        z = tf.concat([z, y], 3)  # Join the channels back, e.g. (64, 8, 8, 16)
        # Then the dense NVP layers (as built by build_nvp_dense_model).
        for layer in nvp_dense_layers:
            z, log_det_jac, _ = layer.forward_and_jacobian(z,
                                                           log_det_jac,
                                                           None,
                                                           y_label=y_label_bs)
        z_shape = nn_extra_nvp.int_shape(z)
        # (batch_size, seq_len, H*W*C): one flattened latent per frame.
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        # The log det jacobian z_i/x_i for every i in the sequence:
        # (batch_size, seq_len).
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE):
            gp_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    # Condition on the context frames, then draw seq_len
                    # samples from the resulting predictive (no updates
                    # between samples).
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    # No context: sample each step, then condition on the
                    # observed latent of that step.
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        gp_layer.update_distribution(z_vec[:, i, :])
            else:  # Log-likelihood (training) mode
                if n_context > 0:
                    # Context frames only condition the prior. Every remaining
                    # step is scored under that same context-conditioned
                    # predictive.
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])

                    for i in range(n_context, seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    # Score every step one-step-ahead, then condition on it.
                    for i in range(seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        gp_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(z_samples,
                                   z_shape)  # (n_samples*seq_len, z_img_shape)

            for layer in reversed(nvp_dense_layers):
                z_samples, _ = layer.backward(z_samples,
                                              None,
                                              y_label=y_label_bs)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples = layer.backward(x_samples,
                                                      z_samples,
                                                      y_label=y_label_bs)

            # Inverse of the forward logit. tf.sigmoid avoids the inf/inf
            # (NaN) gradient of the hand-written 1 / (1 + exp(-x)).
            # NOTE(review): the forward dequantization step is not inverted
            # here -- confirm the samples are in the intended range.
            x_samples = tf.sigmoid(x_samples)
            x_samples = tf.reshape(x_samples,
                                   (z_samples_shape[0], z_samples_shape[1],
                                    x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        if not log_probs:
            raise ValueError(
                'build_model: no time steps to score (n_context={}, '
                'seq_len={}); need n_context < seq_len.'.format(
                    n_context, seq_len))
        # Stack the per-step log-likelihoods along axis 1, presumably giving
        # (batch_size, n_scored_steps).
        log_probs = tf.stack(log_probs, axis=1)

        return log_probs, log_probs, log_probs
Exemplo n.º 2
0
def build_model(x, init=False, sampling_mode=False):
    """Build the RealNVP flow + recurrent latent-prior graph (no labels).

    Args:
        x: input tensor, reshaped below to (x_shape[0] * x_shape[1],
           x_shape[2], x_shape[3], x_shape[4]) -- presumably
           (batch_size, seq_len, H, W, C); TODO confirm.
        init: forwarded via `arg_scope` to every `nn_extra_nvp.conv2d_wn` /
           `nn_extra_nvp.dense_wn` call.
        sampling_mode: build the sampling graph instead of the
           log-likelihood graph.

    Returns:
        If sampling_mode: (x_samples, log_probs). x_samples holds one sample
            drawn before each of the seq_len observed steps plus one after the
            last. Each log_probs entry is that sample's latent log-likelihood
            minus the log-det accumulated by the inverse flow, stacked along
            axis 1.
        Otherwise: (log_probs, latent_log_probs, latent_log_probs_prior,
            z_vec), where:
            - log_probs is latent log-likelihood + forward log-det per step;
            - latent_log_probs is the latent term alone;
            - latent_log_probs_prior is what get_log_likelihood_under_prior
              returns per step;
            - z_vec is the flattened latent of shape
              (x_shape[0], x_shape[1], -1).
    """
    global nvp_layers
    global nvp_dense_layers
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        # Module-level layer lists are built lazily on the first call.
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        # NOTE(review): despite the name, this is a GaussianRecurrentLayer.
        global student_layer
        if student_layer is None:
            student_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # Flatten (batch, seq) so every frame goes through the flow
        # independently.
        x_shape = nn_extra_nvp.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        log_det_jac = tf.zeros(x_bs_shape[0])

        # Created per call; their backward() is reused in sampling mode.
        logit_layer = nn_extra_nvp.LogitLayer()
        scale_layer = nn_extra_nvp.ScaleLayer()

        # Preprocessing: scale, then logit, accumulating log-det terms.
        y, log_det_jac = scale_layer.forward_and_jacobian(
            x_bs, None, log_det_jac)
        y, log_det_jac = logit_layer.forward_and_jacobian(y, None, log_det_jac)

        # construct forward pass
        # Convolutional NVP layers; `z` accumulates the factored-out channels.
        z = None
        for layer in nvp_layers:
            y, z, log_det_jac = layer.forward_and_jacobian(y, z, log_det_jac)

        # Join the channels back, then run the dense NVP layers.
        z = tf.concat([z, y], 3)
        for layer in nvp_dense_layers:
            z, _, log_det_jac = layer.forward_and_jacobian(
                z, None, log_det_jac)

        z_shape = nn_extra_nvp.int_shape(z)
        # One flattened latent per frame: (x_shape[0], x_shape[1], -1).
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []
        latent_log_probs = []
        latent_log_probs_prior = []

        with tf.variable_scope("one_step") as scope:
            student_layer.reset()
            for i in range(seq_len):
                if sampling_mode:
                    # Sample from the current predictive, and record the
                    # sample's own latent log-likelihood.
                    z_sample = student_layer.sample(nr_samples=n_samples)
                    z_samples.append(z_sample)

                    latent_log_prob = student_layer.get_log_likelihood(
                        z_sample[:, 0, :])
                    latent_log_probs.append(latent_log_prob)

                else:
                    # One-step-ahead score of the observed latent.
                    latent_log_prob = student_layer.get_log_likelihood(
                        z_vec[:, i, :])
                    latent_log_probs.append(latent_log_prob)

                    log_prob = latent_log_prob + log_det_jac[:, i]
                    log_probs.append(log_prob)

                    latent_log_prob_prior = student_layer.get_log_likelihood_under_prior(
                        z_vec[:, i, :])
                    latent_log_probs_prior.append(latent_log_prob_prior)

                # In both modes the layer is conditioned on the OBSERVED
                # latent, not on the drawn sample.
                student_layer.update_distribution(z_vec[:, i, :])
                # Variables created on the first step are reused afterwards.
                scope.reuse_variables()

        if sampling_mode:
            # one more sample after seeing the last element in the sequence
            # NOTE(review): this call is outside the "one_step" variable
            # scope -- confirm sample() creates no variables.
            z_sample = student_layer.sample(nr_samples=n_samples)
            z_samples.append(z_sample)
            z_samples = tf.concat(z_samples, 1)

            latent_log_prob = student_layer.get_log_likelihood(z_sample[:,
                                                                        0, :])
            latent_log_probs.append(latent_log_prob)

            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            # Flatten to (z_samples_shape[0] * z_samples_shape[1], H_z, W_z,
            # C_z); dim 1 presumably spans the seq_len + 1 sample steps.
            z_samples = tf.reshape(
                z_samples,
                (z_samples_shape[0] * z_samples_shape[1], z_shape[1],
                 z_shape[2], z_shape[3]))  # (n_samples*seq_len, z_img_shape)

            # Invert the flow in reverse layer order, accumulating the
            # log-det of the inverse map.
            log_det_jac = tf.zeros(z_samples_shape[0] * z_samples_shape[1])
            for layer in reversed(nvp_dense_layers):
                z_samples, _, log_det_jac = layer.backward(
                    z_samples, None, log_det_jac)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples, log_det_jac = layer.backward(
                    x_samples, z_samples, log_det_jac)

            # Undo the preprocessing (logit, then scale).
            x_samples, log_det_jac = logit_layer.backward(
                x_samples, None, log_det_jac)
            x_samples, log_det_jac = scale_layer.backward(
                x_samples, None, log_det_jac)
            x_samples = tf.reshape(x_samples,
                                   (z_samples_shape[0], z_samples_shape[1],
                                    x_shape[2], x_shape[3], x_shape[4]))

            log_det_jac = tf.reshape(log_det_jac,
                                     (z_samples_shape[0], z_samples_shape[1]))
            latent_log_probs = tf.stack(latent_log_probs, axis=1)

            # The inverse-flow log-det is subtracted (it is the log-det of
            # the z -> x map).
            for i in range(seq_len + 1):
                log_prob = latent_log_probs[:, i] - log_det_jac[:, i]
                log_probs.append(log_prob)

            log_probs = tf.stack(log_probs, axis=1)

            return x_samples, log_probs

        # Stack the per-step lists along axis 1.
        log_probs = tf.stack(log_probs, axis=1)
        latent_log_probs = tf.stack(latent_log_probs, axis=1)
        latent_log_probs_prior = tf.stack(latent_log_probs_prior, axis=1)

        return log_probs, latent_log_probs, latent_log_probs_prior, z_vec
Exemplo n.º 3
0
def build_model(x,
                y_label,
                init=False,
                sampling_mode=False,
                glow_width=8,
                glow_num_steps=32,
                glow_num_scales=4):
    """Build the GLOW flow + Gaussian recurrent prior graph for sequences.

    Args:
        x: float32 tensor of shape (config.batch_size,) + config.obs_shape,
           i.e. (batch_size, seq_len, H, W, C); the example configuration is
           (batch_size, seq_len, 32, 32, 1).
        y_label: conditioning labels of shape (batch_size, seq_len, label_dim),
           e.g. (bs, seq_len, 2). They are flattened over (batch, seq) and
           embedded by a 32-unit dense layer named 'labels_layer'.
        init: forwarded via `arg_scope` to every `nn_extra_nvp.conv2d_wn` /
           `nn_extra_nvp.dense_wn` call.
        sampling_mode: build the sampling graph instead of the
           log-likelihood graph.
        glow_width, glow_num_steps, glow_num_scales: passed to
           `build_glow_model`. They only matter while the module-level
           `glow_model` is still None.

    This function is called with
    > model = tf.make_template('model', config.build_model)
    and creates all trainable variables

    Returns:
        If sampling_mode: x_samples reshaped to
            (z_samples_shape[0], z_samples_shape[1], H, W, C).
        Otherwise: (log_probs, log_probs, log_probs) -- the same tensor three
            times, with one column per scored time step
            (steps n_context .. seq_len - 1).

    Raises:
        ValueError: in log-likelihood mode when no time step is scored
            (n_context >= seq_len), since an empty list cannot be stacked.
    """
    # Ensures that all nn_extra_nvp.*_wn layers have init=init
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        # NOTE(review): `glow_model` is only read here; build_glow_model is
        # assumed to assign the module-level instance -- confirm.
        if glow_model is None:
            build_glow_model(glow_width, glow_num_steps, glow_num_scales)

        global gp_layer
        if gp_layer is None:
            gp_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # (batch_size, seq_len, H, W, C) -> (batch_size * seq_len, H, W, C):
        # every frame of every sequence goes through the flow independently.
        x_shape = K.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = K.int_shape(x_bs)

        y_label_shape = K.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))
        # Extract features of the conditioning labels.
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])
        # GLOW does no pre-transformation of its own, so dequantization is
        # applied here and its log-det term is kept.
        x_bs, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)

        # Forward pass through the GLOW flow. InputLayer receives only the
        # data and the label features, not the dequantization log-det, so
        # that term is added back explicitly below.
        input_flow = fl.InputLayer(x_bs, y_label_bs)
        output_flow = glow_model(input_flow, forward=True)
        # (x, log_det, z, y_label). flow_x holds the channels remaining after
        # the last split and flow_z the factored-out ones, e.g.
        # flow_x=[64, 2, 2, 16], flow_z=[64, 2, 2, 240], log_det=[64].
        flow_x, flow_log_det_jac, flow_z, _ = output_flow
        # Total log-det = dequantization term + flow term.
        log_det_jac = log_det_jac + flow_log_det_jac
        # Join the split channels back: factored-out channels FIRST, then the
        # remaining ones (e.g. [64, 2, 2, 256]). The sampling branch below
        # must split in the same order.
        z_channels = K.int_shape(flow_z)[3]
        z = tf.concat([flow_z, flow_x], 3)
        z_shape = K.int_shape(z)
        # (batch_size, seq_len, H*W*C): one flattened latent per frame.
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        # The log det jacobian z_i/x_i for every i in the sequence.
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE):
            gp_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    # Condition on the context frames, then draw seq_len
                    # samples from the resulting predictive (no updates
                    # between samples).
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    # No context: sample each step, then condition on the
                    # observed latent of that step.
                    for i in range(seq_len):
                        z_sample = gp_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        gp_layer.update_distribution(z_vec[:, i, :])
            else:  # Log-likelihood (training) mode
                if n_context > 0:
                    # Context frames only condition the prior. Every remaining
                    # step is scored under that same context-conditioned
                    # predictive.
                    for i in range(n_context):
                        gp_layer.update_distribution(z_vec[:, i, :])

                    for i in range(n_context, seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    # Score every step one-step-ahead, then condition on it.
                    for i in range(seq_len):
                        latent_log_prob = gp_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        gp_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = K.int_shape(z_samples)
            z_samples = tf.reshape(z_samples,
                                   z_shape)  # (n_samples*seq_len, z_img_shape)

            # Undo concat([flow_z, flow_x], 3): the first z_channels channels
            # are the factored-out z, and the rest feed the flow's x slot.
            inverse_z = z_samples[:, :, :, :z_channels]
            inverse_y = z_samples[:, :, :, z_channels:]
            log_det_tmp = tf.zeros_like(flow_log_det_jac)
            inverse_flow = inverse_y, log_det_tmp, inverse_z, y_label_bs

            print(f"Reconstructing model with inverse flow {inverse_flow}")
            reconstruction = glow_model(inverse_flow, forward=False)
            x_samples = reconstruction[0]
            # NOTE(review): the forward dequantization step is not inverted
            # here -- confirm the samples are in the intended range.

            x_samples = tf.reshape(x_samples,
                                   (z_samples_shape[0], z_samples_shape[1],
                                    x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        if not log_probs:
            raise ValueError(
                'build_model: no time steps to score (n_context={}, '
                'seq_len={}); need n_context < seq_len.'.format(
                    n_context, seq_len))
        # Stack the per-step log-likelihoods along axis 1, presumably giving
        # (batch_size, n_scored_steps), e.g. (4, 15).
        log_probs = tf.stack(log_probs, axis=1)
        return log_probs, log_probs, log_probs
Exemplo n.º 4
0
    x1 = rng.multivariate_normal(phi, K)
    x1 = x1.reshape((seq_len, p))
    x1 = np.float32(x1)
    xs.append(x1[None, :, :])

# Stack the per-sequence draws (each appended as x1[None, :, :]) into one
# array along a new leading axis.
x = np.concatenate(xs, axis=0)
print('shape x', x.shape)

# Placeholder for a batch of sequences: (batch_size, seq_len, p).
x_var_tf = tf.placeholder(tf.float32, shape=(batch_size, seq_len, p))
# Two recurrent layers over p-dimensional observations, initialised from the
# same g_* values. The first is a StudentRecurrentLayer (it also takes
# nu_init); the second is a GaussianRecurrentLayer.
l_rnn = nn_extra_student.StudentRecurrentLayer(shape=(p, ),
                                               nu_init=g_nu,
                                               mu_init=g_mu,
                                               var_init=g_var,
                                               corr_init=g_corr)
l_rnn2 = nn_extra_gauss.GaussianRecurrentLayer(shape=(p, ),
                                               mu_init=g_mu,
                                               var_init=g_var,
                                               corr_init=g_corr)

# Per-step log-likelihoods from each layer, scored one-step-ahead: score x_t
# under the current predictive, then condition the layer on x_t.
probs = []
probs_gauss = []
# NOTE(review): `scope` is bound but never used in the visible code.
with tf.variable_scope("one_step") as scope:
    l_rnn.reset()
    l_rnn2.reset()
    for i in range(seq_len):
        prob_i = l_rnn.get_log_likelihood(x_var_tf[:, i, :])
        probs.append(prob_i)
        l_rnn.update_distribution(x_var_tf[:, i, :])

        prob_i = l_rnn2.get_log_likelihood(x_var_tf[:, i, :])
        probs_gauss.append(prob_i)
        l_rnn2.update_distribution(x_var_tf[:, i, :])
Exemplo n.º 5
0
def build_model(x, y_label, init=False, sampling_mode=False):
    """Build the RealNVP flow + recurrent latent-prior graph for sequences.

    Args:
        x: tensor reshaped below to (x_shape[0] * x_shape[1], x_shape[2],
           x_shape[3], x_shape[4]) -- presumably (batch_size, seq_len, H, W, C).
        y_label: conditioning labels of shape (batch_size, seq_len, label_dim).
           They are flattened over (batch, seq) and embedded by a 32-unit dense
           layer named 'labels_layer'.
        init: forwarded via `arg_scope` to every `nn_extra_nvp.conv2d_wn` /
           `nn_extra_nvp.dense_wn` call.
        sampling_mode: build the sampling graph instead of the
           log-likelihood graph.

    Returns:
        If sampling_mode: x_samples reshaped to
            (z_samples_shape[0], z_samples_shape[1], H, W, C).
        Otherwise: (log_probs, log_probs, log_probs) -- the same tensor three
            times, with one column per scored time step
            (steps n_context .. seq_len - 1).

    Raises:
        ValueError: in log-likelihood mode when no time step is scored
            (n_context >= seq_len), since an empty list cannot be stacked.
    """
    global nvp_layers
    global nvp_dense_layers
    # Ensures that all nn_extra_nvp.*_wn layers have init=init
    with arg_scope([nn_extra_nvp.conv2d_wn, nn_extra_nvp.dense_wn], init=init):
        # Module-level layer lists are built lazily on the first call.
        if len(nvp_layers) == 0:
            build_nvp_model()

        if len(nvp_dense_layers) == 0:
            build_nvp_dense_model()

        # NOTE(review): despite the name, this is a GaussianRecurrentLayer.
        global student_layer
        if student_layer is None:
            student_layer = nn_extra_gauss.GaussianRecurrentLayer(
                shape=(ndim, ), corr_init=corr_init)

        # Flatten (batch, seq) so every frame goes through the flow
        # independently.
        x_shape = nn_extra_nvp.int_shape(x)
        x_bs = tf.reshape(
            x, (x_shape[0] * x_shape[1], x_shape[2], x_shape[3], x_shape[4]))
        x_bs_shape = nn_extra_nvp.int_shape(x_bs)

        y_label_shape = nn_extra_nvp.int_shape(y_label)
        y_label_bs = tf.reshape(
            y_label, (y_label_shape[0] * y_label_shape[1], y_label_shape[2]))
        y_label_bs = tf.layers.dense(
            y_label_bs,
            units=32,
            activation=tf.nn.leaky_relu,
            kernel_initializer=nn_extra_nvp.Orthogonal(),
            use_bias=True,
            name='labels_layer')

        log_det_jac = tf.zeros(x_bs_shape[0])

        # Preprocess (dequantize, then logit), accumulating log-det terms.
        y, log_det_jac = nn_extra_nvp.dequantization_forward_and_jacobian(
            x_bs, log_det_jac)
        y, log_det_jac = nn_extra_nvp.logit_forward_and_jacobian(
            y, log_det_jac)

        # Convolutional NVP layers; `z` accumulates the factored-out channels.
        z = None
        for layer in nvp_layers:
            y, log_det_jac, z = layer.forward_and_jacobian(y,
                                                           log_det_jac,
                                                           z,
                                                           y_label=y_label_bs)

        # Join the channels back, then run the dense NVP layers.
        z = tf.concat([z, y], 3)
        for layer in nvp_dense_layers:
            z, log_det_jac, _ = layer.forward_and_jacobian(z,
                                                           log_det_jac,
                                                           None,
                                                           y_label=y_label_bs)

        z_shape = nn_extra_nvp.int_shape(z)
        # One flattened latent per frame: (x_shape[0], x_shape[1], -1).
        z_vec = tf.reshape(z, (x_shape[0], x_shape[1], -1))
        log_det_jac = tf.reshape(log_det_jac, (x_shape[0], x_shape[1]))

        log_probs = []
        z_samples = []

        with tf.variable_scope("one_step", reuse=tf.AUTO_REUSE):
            student_layer.reset()
            if sampling_mode:
                if n_context > 0:
                    # Condition on the context frames, then draw seq_len
                    # samples from the resulting predictive (no updates
                    # between samples).
                    for i in range(n_context):
                        student_layer.update_distribution(z_vec[:, i, :])
                    for i in range(seq_len):
                        z_sample = student_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                else:
                    # No context: sample each step, then condition on the
                    # observed latent of that step.
                    for i in range(seq_len):
                        z_sample = student_layer.sample(nr_samples=1)
                        z_samples.append(z_sample)
                        student_layer.update_distribution(z_vec[:, i, :])
            else:  # Log-likelihood (training) mode
                if n_context > 0:
                    # Context frames only condition the prior. Every remaining
                    # step is scored under that same context-conditioned
                    # predictive.
                    for i in range(n_context):
                        student_layer.update_distribution(z_vec[:, i, :])

                    for i in range(n_context, seq_len):
                        latent_log_prob = student_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                else:
                    # Score every step one-step-ahead, then condition on it.
                    for i in range(seq_len):
                        latent_log_prob = student_layer.get_log_likelihood(
                            z_vec[:, i, :])
                        log_prob = latent_log_prob + log_det_jac[:, i]
                        log_probs.append(log_prob)
                        student_layer.update_distribution(z_vec[:, i, :])

        if sampling_mode:
            z_samples = tf.concat(z_samples, 1)
            z_samples_shape = nn_extra_nvp.int_shape(z_samples)
            z_samples = tf.reshape(z_samples,
                                   z_shape)  # (n_samples*seq_len, z_img_shape)

            for layer in reversed(nvp_dense_layers):
                z_samples, _ = layer.backward(z_samples,
                                              None,
                                              y_label=y_label_bs)

            x_samples = None
            for layer in reversed(nvp_layers):
                x_samples, z_samples = layer.backward(x_samples,
                                                      z_samples,
                                                      y_label=y_label_bs)

            # Inverse of the forward logit. tf.sigmoid avoids the inf/inf
            # (NaN) gradient of the hand-written 1 / (1 + exp(-x)).
            # NOTE(review): the forward dequantization step is not inverted
            # here -- confirm the samples are in the intended range.
            x_samples = tf.sigmoid(x_samples)
            x_samples = tf.reshape(x_samples,
                                   (z_samples_shape[0], z_samples_shape[1],
                                    x_shape[2], x_shape[3], x_shape[4]))
            return x_samples

        if not log_probs:
            raise ValueError(
                'build_model: no time steps to score (n_context={}, '
                'seq_len={}); need n_context < seq_len.'.format(
                    n_context, seq_len))
        # Stack the per-step log-likelihoods along axis 1, presumably giving
        # (batch_size, n_scored_steps).
        log_probs = tf.stack(log_probs, axis=1)

        return log_probs, log_probs, log_probs