Example No. 1
def checkpoint(z, logdet):
    zshape = Z.int_shape(z)
    z = tf.reshape(z, [-1, zshape[1]*zshape[2]*zshape[3]])
    logdet = tf.reshape(logdet, [-1, 1])
    combined = tf.concat([z, logdet], axis=1)
    tf.add_to_collection('checkpoints', combined)
    logdet = combined[:, -1]
    z = tf.reshape(combined[:, :-1], [-1, zshape[1], zshape[2], zshape[3]])
    return z, logdet
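
The helper above packs z and logdet into one flat tensor so the pair can be registered in the 'checkpoints' collection (used for gradient checkpointing) and then unpacked again. A minimal numpy sketch of that pack/unpack round trip, with hypothetical shapes:

# Minimal numpy sketch (not from this listing) of the pack/unpack round trip
# that checkpoint() performs on (z, logdet); shapes are hypothetical.
import numpy as np

z = np.random.randn(4, 8, 8, 3).astype('float32')   # [batch, height, width, channels]
logdet = np.random.randn(4).astype('float32')        # one scalar per sample

combined = np.concatenate([z.reshape(4, -1), logdet.reshape(4, 1)], axis=1)

logdet_back = combined[:, -1]
z_back = combined[:, :-1].reshape(4, 8, 8, 3)

assert np.allclose(z, z_back) and np.allclose(logdet, logdet_back)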
Example No. 2
def split2d(name, z, objective=0.):
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        z1 = z[:, :, :, :n_z // 2]
        z2 = z[:, :, :, n_z // 2:]
        pz = split2d_prior(z1)
        objective += pz.logp(z2)
        z1 = Z.squeeze2d(z1)
        eps = pz.get_eps(z2)
        return z1, objective, eps
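
split2d keeps half of the channels, scores the dropped half z2 under a prior conditioned on z1, and space-to-depth squeezes the kept half. A shape-bookkeeping sketch in plain numpy, assuming Z.squeeze2d is the usual Glow factor-2 space-to-depth (an assumption, not taken from this listing):

# Hedged numpy sketch of the shapes involved in split2d; squeeze2d_np mimics
# the assumed behaviour of Z.squeeze2d with factor 2.
import numpy as np

def squeeze2d_np(x, factor=2):
    b, h, w, c = x.shape
    x = x.reshape(b, h // factor, factor, w // factor, factor, c)
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(b, h // factor, w // factor, c * factor * factor)

z = np.random.randn(2, 16, 16, 12)
z1, z2 = z[..., :6], z[..., 6:]        # channel split into two [2, 16, 16, 6] halves
z1_squeezed = squeeze2d_np(z1)         # -> [2, 8, 8, 24]
print(z1.shape, z2.shape, z1_squeezed.shape)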
Example No. 3
    def _f_loss(x, y, is_training, reuse=False):

        with tf.variable_scope('model', reuse=reuse):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
            z = preprocess(x)
            z = z + tf.random_uniform(tf.shape(z), 0, 1./hps.n_bins)
            objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

            # Encode
            z = Z.squeeze2d(z, 2)  # > 16x16x12
            z, objective, _ = encoder(z, objective)

            # Prior
            hps.top_shape = Z.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, hps)
            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if hps.weight_y > 0 and hps.ycond:

                # Classification loss
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error
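
The generative term above is reported in bits per subpixel: the negative objective in nats, including the -log(n_bins) dequantization correction, divided by log(2) and by the number of subpixels h*w*c. A numpy sketch with made-up numbers:

# Bits-per-subpixel conversion, sketched with hypothetical numbers.
import numpy as np

h, w, c = 32, 32, 3
n_bins = 256.0
objective = -6800.0                        # made-up flow objective in nats
objective += -np.log(n_bins) * h * w * c   # dequantization correction

bits_x = -objective / (np.log(2.) * h * w * c)
# The dequantization term alone contributes log2(256) = 8 bits per subpixel;
# with the made-up objective this prints roughly 11.2.
print(bits_x)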
Example No. 4
        def f_encode(x, y, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                # Discrete -> Continuous
                objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
                z = preprocess(x)
                z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
                objective += -np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

                # Encode
                z = Z.squeeze2d(z, 2)  # > 16x16x12
                z, objective, eps = encoder(z, objective)

                # Prior
                hps.top_shape = Z.int_shape(z)[1:]
                logp, _, _eps = prior("prior", y_onehot, hps)
                objective += logp(z)
                eps.append(_eps(z))

            return eps
Example No. 5
def split3d(name, level, z, y_onehot, z_prior=None, objective=0.):
    with tf.variable_scope(name + str(level)):
        n_z = Z.int_shape(z)[4]
        z1 = z[:, :, :, :, :n_z // 2]
        z2 = z[:, :, :, :, n_z // 2:]
        shape = [tf.shape(z1)[0]] + Z.int_shape(z1)[1:]
        #############################
        # z_p = z1
        # if z_prior is not None:
        #     n_z_prior = Z.int_shape(z_prior)[3]
        #     n_z_p = Z.int_shape(z_p)[3]
        #     # w = tf.get_variable("W_split", [1, 1, n_z_prior, n_z_p], tf.float32,
        #     #                     initializer=tf.zeros_initializer())
        #     # z_p -= tf.nn.conv2d(z_prior, w, strides=[1, 1, 1, 1], padding='SAME')###########!!!!!!!!!!####### +  or - ##
        #     # z_p -= Z.conv2d_zeros('p_o', z_prior, n_z_prior, n_z_p)
        #     z_p += Z.myMLP(3, z_prior, n_z_prior, n_z_p)
        #############################
        pz = split3d_prior(y_onehot, shape, z_prior, level)
        objective += pz.logp(z2)
        z1 = Z.squeeze3d(z1)
        eps = pz.get_eps(z2)
        return z1, z2, objective, eps
Example No. 6
def split2d_prior(z, hps):
    shape = Z.int_shape(z)
    n_z2 = int(z.get_shape()[3])
    n_z1 = n_z2

    h = tf.zeros([tf.shape(z)[0]] + shape[1:3] + [2 * n_z1])

    if hps.learnprior:
        h = Z.conv2d_zeros("conv", z, 2 * n_z1)

    mean = h[:, :, :, 0::2]
    logs = h[:, :, :, 1::2]

    return Z.gaussian_diag(mean, logs)
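
Z.gaussian_diag(mean, logs) is assumed here to be a diagonal Gaussian with log standard deviation logs; under that assumption its logp is the standard formula, sketched in numpy:

# Standard diagonal-Gaussian log-density (an assumption about what
# Z.gaussian_diag exposes as .logp, not code from this listing).
import numpy as np

def gaussian_diag_logp(x, mean, logs):
    return np.sum(
        -0.5 * (np.log(2 * np.pi) + 2. * logs + (x - mean) ** 2 / np.exp(2. * logs)),
        axis=tuple(range(1, x.ndim)))

x = np.random.randn(2, 4, 4, 3)
logp = gaussian_diag_logp(x, np.zeros_like(x), np.zeros_like(x))
print(logp.shape)   # one log-probability per batch element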
Example No. 7
def split3d_reverse(name, level, z, y_onehot, z_provided, eps, eps_std, z_prior=None):
    with tf.variable_scope(name + str(level)):

        z1 = Z.unsqueeze3d(z)

        # n_z = Z.int_shape(z1)[3]
        shape = [tf.shape(z1)[0]] + Z.int_shape(z1)[1:]

        # z_p = z1
        #############################
        # if z_prior is not None:
        #     #z_prior = Z.unsqueeze2d(z_prior)
        #     n_z_prior = Z.int_shape(z_prior)[3]
        #     # w = tf.get_variable("W_split", [1, 1, n_z_prior, n_z], tf.float32,
        #     #                     initializer=tf.zeros_initializer())
        #     # z_p -= tf.nn.conv2d(z_prior, w, strides=[1, 1, 1, 1], padding='SAME') ###########!!!!!!!!!!####### +  or - ##
        #
        #     z_p += Z.myMLP(3, z_prior, n_z_prior, n_z)
        # #############################

        pz = split3d_prior(y_onehot, shape, z_prior, level)

        if z_provided is not None:
            y_onehot2 = (y_onehot - 0.5) * (-1) + 0.5
            # y_onehot = tf.zeros_like(y_onehot)
            # y_onehot2 = tf.ones_like(y_onehot)
            pz2_ = split3d_prior(y_onehot2, shape, z_prior, level)
            # z2 = z_provided +  pz.mean - pz2_.mean
            z2 = z_provided - pz.mean + pz2_.mean  # + 0.5 * (pz.logsd - pz2_.logsd)
            # z2 = pz2_.sample2(pz.get_eps(z_provided * 0.5))
            # pz2_.mean + 0.6 * tf.exp(pz2_.logsd)
        else:
            if eps is not None:
                # Already sampled eps
                z2 = pz.sample2(eps)
            elif eps_std is not None:
                # Sample with given eps_std
                z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1, 1]))
            else:
                # Sample normally
                z2 = pz.sample

        z = tf.concat([z1, z2], 4)
        return z
Example No. 8
def test_derivative_fourier_conv():
    print('Testing gradients')
    shape = [128, 32, 32, 3]

    x = tf.placeholder(tf.float32, shape, name='image')
    x_np = np.random.randn(*shape).astype('float32')

    logdet = tf.zeros_like(x)[:, 0, 0, 0]

    with tf.variable_scope('test'):
        z = x
        z, logdet = fourier_conv('layer', z, logdet, reverse=False)

    with tf.variable_scope('test', reuse=True):
        w = tf.get_variable('layer/W')

    f = tf.reduce_sum(logdet)

    grad = tf.gradients(f, w)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        w_np = sess.run(w)

        grad_np = sess.run(grad, feed_dict={x: x_np})

        delta = 0.0001

        v = np.random.randn(*Z.int_shape(w))

        finite_diff = (sess.run(f, feed_dict={
            x: x_np,
            w: w_np + delta * v
        }) - sess.run(f, feed_dict={
            x: x_np,
            w: w_np - delta * v
        })) / 2 / delta

        other_side = np.sum(grad_np * v)

        print(finite_diff, other_side, finite_diff - other_side)

        print(finite_diff / other_side)
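
The test compares the analytic gradient, projected onto a random direction v, with a central finite difference of the scalar objective. The same check on a toy function, self-contained in numpy:

# Directional finite-difference gradient check on a toy function
# f(w) = sum(sin(w)), so it runs without the flow layers above.
import numpy as np

def f(w):
    return np.sum(np.sin(w))

def grad_f(w):
    return np.cos(w)

w = np.random.randn(3, 3)
v = np.random.randn(3, 3)      # random direction
delta = 1e-4

finite_diff = (f(w + delta * v) - f(w - delta * v)) / (2 * delta)
analytic = np.sum(grad_f(w) * v)   # directional derivative <grad f, v>

print(finite_diff, analytic, finite_diff - analytic)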
Example No. 9
    def _f_loss(x_A, y_A, x_B, y_B, is_training, reuse=False, init=False):
        with tf.variable_scope('model_A', reuse=reuse):
            y_onehot_A = tf.cast(tf.one_hot(y_A, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective_A = tf.zeros_like(x_A, dtype='float32')[:, 0, 0, 0]
            z_A = preprocess(x_A)
            z_A = z_A + tf.random_uniform(tf.shape(z_A), 0, 1./hps.n_bins)
            objective_A += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_A)[1:])

            # Encode
            z_A = Z.squeeze2d(z_A, 2)  # > 16x16x12
            z_A, objective_A, eps_A = encoder_A(z_A, objective_A)

            # Prior
            hps.top_shape = Z.int_shape(z_A)[1:]
            logp_A, _, _eps_A = prior("prior", y_onehot_A, hps)
            objective_A += logp_A(z_A)

            # Note that we learn the top layer so need to process z
            z_A = _eps_A(z_A)
            eps_A.append(z_A)

            # Loss of eps and flatten latent code from another model
            eps_flatten_A = tf.concat(
                [tf.contrib.layers.flatten(e) for e in eps_A], axis=-1)

        with tf.variable_scope('model_B', reuse=reuse):
            y_onehot_B = tf.cast(tf.one_hot(y_B, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective_B = tf.zeros_like(x_B, dtype='float32')[:, 0, 0, 0]
            z_B = preprocess(x_B)
            z_B = z_B + tf.random_uniform(tf.shape(z_B), 0, 1./hps.n_bins)
            objective_B += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_B)[1:])

            # Encode
            z_B = Z.squeeze2d(z_B, 2)  # > 16x16x12
            z_B, objective_B, eps_B = encoder_B(z_B, objective_B)

            # Prior
            hps.top_shape = Z.int_shape(z_B)[1:]
            logp_B, _, _eps_B = prior("prior", y_onehot_B, hps)
            objective_B += logp_B(z_B)

            # Note that we learn the top layer so need to process z
            z_B = _eps_B(z_B)
            eps_B.append(z_B)

            # Loss of eps and flatten latent code from another model
            eps_flatten_B = tf.concat(
                [tf.contrib.layers.flatten(e) for e in eps_B], axis=-1)

        code_loss = 0.0
        code_shapes = [[16, 16, 6], [8, 8, 12], [4, 4, 48]]
        if hps.code_loss_type == 'B_all':
            if not init:
                """ Decode the code from another model and compute L2 loss
                    at pixel level
                """
                def unflatten_code(fcode, code_shapes):
                    index = 0
                    code = []
                    bs = tf.shape(fcode)[0]
                    # bs = hps.local_batch_train
                    for shape in code_shapes:
                        code.append(tf.reshape(fcode[:, index:index+np.prod(shape)],
                                               tf.convert_to_tensor([bs] + shape)))
                        index += np.prod(shape)
                    return code

                code_others = unflatten_code(eps_flatten_A, code_shapes)
                # code_others[-1] is z, and code_others[:-1] is eps
                with tf.variable_scope('model_B', reuse=True):
                    _, sample, _ = prior("prior", y_onehot_B, hps)
                    code_last_others = sample(eps=code_others[-1])
                    code_decoded_others = decoder_B(
                        code_last_others, code_others[:-1])
                code_decoded = Z.unsqueeze2d(code_decoded_others, 2)
                x_B_recon = postprocess(code_decoded)
                x_B_scaled = 1/255.0 * tf.cast(x_B, tf.float32)
                x_B_recon_scaled = 1/255.0 * tf.cast(x_B_recon, tf.float32)
                if hps.code_loss_fn == 'l1':
                    code_loss = tf.reduce_mean(tf.losses.absolute_difference(
                        x_B_scaled, x_B_recon_scaled))
                elif hps.code_loss_fn == 'l2':
                    code_loss = tf.reduce_mean(tf.squared_difference(
                        x_B_scaled, x_B_recon_scaled))
                else:
                    raise NotImplementedError()
        elif hps.code_loss_type == 'code_all':
            code_loss = tf.reduce_mean(
                tf.squared_difference(eps_flatten_A, eps_flatten_B))
        elif hps.code_loss_type == 'code_last':
            dim = np.prod(code_shapes[-1])
            code_loss = tf.reduce_mean(tf.squared_difference(
                eps_flatten_A[:, -dim:], eps_flatten_B[:, -dim:]))
        else:
            raise NotImplementedError()

        with tf.variable_scope('model_A', reuse=True):
            # Generative loss
            nobj_A = - objective_A
            bits_x_A = nobj_A / (np.log(2.) * int(x_A.get_shape()[1]) * int(
                x_A.get_shape()[2]) * int(x_A.get_shape()[3]))  # bits per subpixel
            bits_y_A = tf.zeros_like(bits_x_A)
            classification_error_A = tf.ones_like(bits_x_A)

        with tf.variable_scope('model_B', reuse=True):
            # Generative loss
            nobj_B = - objective_B
            bits_x_B = nobj_B / (np.log(2.) * int(x_B.get_shape()[1]) * int(
                x_B.get_shape()[2]) * int(x_B.get_shape()[3]))  # bits per subpixel
            bits_y_B = tf.zeros_like(bits_x_B)
            classification_error_B = tf.ones_like(bits_x_B)

        return (bits_x_A, bits_y_A, classification_error_A, eps_flatten_A,
                bits_x_B, bits_y_B, classification_error_B, eps_flatten_B, code_loss)
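
unflatten_code slices the flat latent vector back into per-level tensors using code_shapes. A numpy round trip of the flatten/unflatten pair with the same shapes:

# Flatten/unflatten round trip for the per-level codes (numpy sketch, same
# code_shapes as above).
import numpy as np

code_shapes = [[16, 16, 6], [8, 8, 12], [4, 4, 48]]
bs = 2

code = [np.random.randn(bs, *s).astype('float32') for s in code_shapes]
flat = np.concatenate([c.reshape(bs, -1) for c in code], axis=-1)

index, code_back = 0, []
for s in code_shapes:
    n = int(np.prod(s))
    code_back.append(flat[:, index:index + n].reshape(bs, *s))
    index += n

assert all(np.allclose(a, b) for a, b in zip(code, code_back))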
Example No. 10
def invertible_conv2D_emerging(name,
                               z,
                               logdet,
                               ksize=3,
                               dilation=1,
                               reverse=False,
                               checkpoint_fn=None):
    batchsize, height, width, n_channels = Z.int_shape(z)

    assert (ksize - 1) % 2 == 0

    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_square_ar_mask(
            ksize, ksize, n_channels, n_channels,
            zerodiagonal=True)[::-1, ::-1, ::-1, ::-1].copy()
        mask = tf.constant(mask_np)

        print(mask_np.transpose(3, 2, 0, 1))

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask

        s_np = (1 + np.random.randn(n_channels) * 0.02).astype('float32')
        s = tf.get_variable('scale', dtype=tf.float32, initializer=s_np)
        s = tf.reshape(s, [1, 1, 1, n_channels])

        def flat(z):
            return tf.reshape(z, [batchsize, height * width * n_channels])

        def unflat(z):
            return tf.reshape(z, [batchsize, height, width, n_channels])

        def shift_and_log_scale_fn_volume_preserving_1(z_flat):
            z = unflat(z_flat)

            shift = tf.nn.conv2d(z,
                                 w1, [1, 1, 1, 1],
                                 dilations=[1, dilation, dilation, 1],
                                 padding='SAME',
                                 data_format='NHWC')

            shift_flat = flat(shift)

            return shift_flat, tf.zeros_like(shift_flat)

        def shift_and_log_scale_fn_volume_preserving_2(z_flat):
            z = unflat(z_flat)

            shift = tf.nn.conv2d(z,
                                 w2, [1, 1, 1, 1],
                                 dilations=[1, dilation, dilation, 1],
                                 padding='SAME',
                                 data_format='NHWC')

            shift_flat = flat(shift)

            return shift_flat, tf.zeros_like(shift_flat)

        flow1 = tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn_volume_preserving_1)

        flow2 = tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn_volume_preserving_2)

        def flip(z_flat):
            z = unflat(z_flat)
            z = z[:, ::-1, ::-1, ::-1]
            z = flat(z)
            return z

        def forward(z, logdet):
            z = z * s
            logdet += tf.reduce_sum(tf.log(tf.abs(s))) * (height * width)

            z_flat = flat(z)

            z_flat = flow1.forward(z_flat)

            z_flat = flip(z_flat)
            z_flat = flow2.forward(z_flat)
            z_flat = flip(z_flat)

            z = unflat(z_flat)

            z = z + b
            return z, logdet

        def inverse(z, logdet):
            z = z - b

            z_flat = flat(z)

            z_flat = flip(z_flat)
            z_flat = flow2.inverse(z_flat)
            z_flat = flip(z_flat)

            z_flat = flow1.inverse(z_flat)

            z = unflat(z_flat)

            z = z / s
            logdet -= tf.reduce_sum(tf.log(tf.abs(s))) * (height * width)

            return z, logdet

        if not reverse:
            x, logdet = forward(z, logdet)

            return x, logdet

        else:
            x, logdet = inverse(z, logdet)

            return x, logdet
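
In this layer only the per-channel scale s changes volume; the two masked convolutions are shift-only and therefore volume preserving (their log-scale outputs are zero). A numpy sketch of the scale term that matches the logdet update in forward:

# Per-channel scaling and its log-determinant (numpy sketch of the only
# non-volume-preserving part of the layer above).
import numpy as np

b, h, w, c = 2, 8, 8, 3
s = (1 + 0.02 * np.random.randn(c)).astype('float32')
z = np.random.randn(b, h, w, c).astype('float32')

z_scaled = z * s.reshape(1, 1, 1, c)
dlogdet = np.sum(np.log(np.abs(s))) * h * w   # same form as the update in forward()
print(dlogdet)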
Example No. 11
def fourier_conv(name,
                 z,
                 logdet,
                 ksize=3,
                 reverse=False,
                 checkpoint_fn=None,
                 use_fourier_forward=False):
    batchsize, height, width, n_channels = Z.int_shape(z)

    assert (ksize - 1) % 2 == 0

    with tf.variable_scope(name):
        filter_shape = [ksize, ksize, n_channels, n_channels]

        w_np = get_conv_weight_np(filter_shape)
        w = tf.get_variable('W', dtype=tf.float32, initializer=w_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        f_shape = [height, width]

        def forward(z, w, logdet):
            padsize = (ksize - 1) // 2
            # Circular padding.
            z = tf.concat((z[:, -padsize:, :], z, z[:, :padsize, :]), axis=1)

            z = tf.concat((z[:, :, -padsize:], z, z[:, :, :padsize]), axis=2)

            # Circular convolution (due to padding).
            z = tf.nn.conv2d(z,
                             w, [1, 1, 1, 1],
                             padding='VALID',
                             data_format='NHWC')

            # Fourier transform for log determinant.
            w_fft = tf.spectral.rfft2d(tf.transpose(
                w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                                       fft_length=f_shape,
                                       name=None)
            dlogdet = compute_logdet(w_fft, width)

            logdet += dlogdet

            z = z + b

            return z, logdet

        def forward_fourier(x, w, logdet):
            # Dimension [b, c, v, u]
            x_fft = tf.spectral.rfft2d(tf.transpose(x, [0, 3, 1, 2]),
                                       fft_length=f_shape,
                                       name=None)

            # Dimension [b, 1, c_in, v, u]
            x_fft = tf.expand_dims(x_fft, 1)

            # Dimension [c_out, c_in, v, u]
            w_fft = tf.spectral.rfft2d(tf.transpose(
                w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                                       fft_length=f_shape,
                                       name=None)

            logdet += compute_logdet(w_fft, width)

            # Dimension [1, c_out, c_in, v, u]
            w_fft = tf.expand_dims(w_fft, 0)

            z_fft = tf.reduce_sum(tf.multiply(x_fft, w_fft), axis=2)

            z = tf.spectral.irfft2d(
                z_fft,
                fft_length=f_shape,
            )

            z = tf.transpose(z, [0, 2, 3, 1])

            z = reindex(z)

            z = z + b
            return z, logdet

        def inverse(z, logdet):
            z = z - b

            z = reindex(z, reverse=True)

            # Dimension [b, c_out, v, u]
            z_fft = tf.spectral.rfft2d(tf.transpose(z, [0, 3, 1, 2]),
                                       fft_length=f_shape,
                                       name=None)

            # Dimension [b, 1, c_out, v, u]
            z_fft = tf.expand_dims(z_fft, 1)

            # Dimension [c_out, c_in, v, u]
            w_fft = tf.spectral.rfft2d(tf.transpose(
                w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                                       fft_length=f_shape,
                                       name=None)

            dlogdet = compute_logdet(w_fft, width)

            # z_fft = tf.Print(
            #     z_fft, data=[dlogdet / height / width], message='dlogdet:')

            logdet -= dlogdet

            # Dimension [v, u, c_in, c_out], channels switched because of
            # inverse.
            w_fft_inv = tf.linalg.inv(tf.transpose(w_fft, [2, 3, 0, 1]), )
            # Dimension [c_in, c_out, v, u]
            w_fft_inv = tf.transpose(w_fft_inv, [2, 3, 0, 1])

            # Dimension [1, c_in, c_out, v, u]
            w_fft_inv = tf.expand_dims(w_fft_inv, 0)

            x_fft = tf.reduce_sum(tf.multiply(z_fft, w_fft_inv), axis=2)

            x = tf.spectral.irfft2d(
                x_fft,
                fft_length=f_shape,
            )

            x = tf.transpose(x, [0, 2, 3, 1])

            return x, logdet

        if not reverse:
            x = z

            if use_fourier_forward:
                z, logdet = forward_fourier(x, w, logdet)
            else:
                z, logdet = forward(x, w, logdet)

            return z, logdet

        else:
            z, logdet = inverse(z, logdet)

            return z, logdet
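
compute_logdet(w_fft, width) reads the log-determinant of the circular convolution off the filter's DFT. The idea is easiest to check in 1-D with one channel: a circular convolution is a circulant matrix whose eigenvalues are the DFT of the kernel, so the log-determinant is the sum of log-magnitudes over frequencies. A numpy sketch (not from this listing):

# 1-D, single-channel check: log|det| of a circular convolution equals the
# sum of log-magnitudes of the kernel's DFT.
import numpy as np

n = 8
kernel = np.random.randn(n)

# Circulant matrix whose columns are rolled copies of the kernel; multiplying
# by it performs a circular convolution.
C = np.stack([np.roll(kernel, j) for j in range(n)], axis=1)

sign, logdet = np.linalg.slogdet(C)
logdet_fft = np.sum(np.log(np.abs(np.fft.fft(kernel))))
print(logdet, logdet_fft)   # agree up to numerical error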
Example No. 12
def revnet2d_step(name, z, logdet, hps, reverse):
    with tf.variable_scope(name):

        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0

        if not reverse:

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, checkpoint_fn=checkpoint)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet)
            elif hps.flow_permutation == 5:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = maf_three('maf1',
                                      z,
                                      logdet,
                                      depth=96,
                                      is_upper=False)
                z, logdet = maf_three('maf2',
                                      z,
                                      logdet,
                                      depth=96,
                                      is_upper=True)
            else:
                raise Exception()

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                logscale = tf.log_sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(logscale, axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

        else:

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                logscale = tf.log_sigmoid(h[:, :, :, 1::2] + 2.)
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(logscale, axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv",
                    z,
                    logdet,
                    reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_conv2D_emerging("emerging",
                                                       z,
                                                       logdet,
                                                       reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv",
                    z,
                    logdet,
                    reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet, reverse=True)
            elif hps.flow_permutation == 5:
                z, logdet = maf_three('maf2',
                                      z,
                                      logdet,
                                      depth=96,
                                      is_upper=True,
                                      reverse=True)
                z, logdet = maf_three('maf1',
                                      z,
                                      logdet,
                                      depth=96,
                                      is_upper=False,
                                      reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv",
                    z,
                    logdet,
                    decomposition=hps.decomposition,
                    reverse=True)
            else:
                raise Exception()

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
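
The affine coupling branch is invertible in closed form and contributes the sum of log(scale) to logdet. A numpy sketch of the forward/inverse round trip with the same sigmoid(h + 2) parameterization; toy_f below is only a stand-in for f("f1", ...):

# Forward/inverse of the affine coupling, sketched in numpy with a fixed toy
# "network" so the round trip can be checked directly.
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def toy_f(z1, n_out):
    # stand-in for f("f1", z1, width, n_z): any deterministic map works here
    rng = np.random.RandomState(0)
    return z1 @ (rng.randn(z1.shape[-1], n_out) * 0.1)

z = np.random.randn(2, 4, 4, 8)
n_z = z.shape[-1]
z1, z2 = z[..., :n_z // 2], z[..., n_z // 2:]

h = toy_f(z1, n_z)
shift, scale = h[..., 0::2], sigmoid(h[..., 1::2] + 2.)

y2 = (z2 + shift) * scale                        # forward
logdet = np.sum(np.log(scale), axis=(1, 2, 3))

z2_back = y2 / scale - shift                     # inverse
assert np.allclose(z2, z2_back)
print(logdet.shape)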
Example No. 13
def invertible_1x1_conv(name, z, logdet, reverse=False):

    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):

            shape = Z.int_shape(z)
            w_shape = [shape[3], shape[3]]

            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                *w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]

            if not reverse:

                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:

                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1]+w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet

                return z, logdet

    else:

        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]

            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
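
A 1x1 convolution with weight W applies the same channel mixing at every spatial position, so its log-determinant is h*w*log|det W| and its inverse is a 1x1 convolution with the inverse matrix. A per-pixel matrix-multiply sketch in numpy:

# What tf.nn.conv2d with a [1, 1, c, c] kernel does per pixel, in numpy:
# forward with W, inverse with W^{-1}, log-determinant h*w*log|det W|.
import numpy as np

b, h, w, c = 2, 4, 4, 3
W = np.linalg.qr(np.random.randn(c, c))[0]   # random orthogonal init, as above
x = np.random.randn(b, h, w, c)

y = x.reshape(-1, c) @ W
x_back = (y @ np.linalg.inv(W)).reshape(b, h, w, c)

dlogdet = h * w * np.log(np.abs(np.linalg.det(W)))

assert np.allclose(x, x_back)
print(dlogdet)   # ~0 for an orthogonal W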
Example No. 14
def invertible_1x1_conv(name,
                        z,
                        logdet,
                        decomposition=None,
                        reverse=False,
                        unit_testing=False):
    shape = Z.int_shape(z)
    w_shape = [shape[3], shape[3]]

    if decomposition is None or decomposition == '':
        with tf.variable_scope(name):
            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(
                np.random.randn(*w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(
                tf.log(abs(tf.matrix_determinant(tf.cast(w, 'float64')))),
                'float32') * shape[1] * shape[2]

            if not reverse:

                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 _w, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:
                # z = tf.Print(
                #     z,
                #     data=[dlogdet / shape[1] / shape[2]],
                #     message='logdet invconv foreach spatial location: ')

                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 _w, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet -= dlogdet

                return z, logdet

    elif decomposition == 'PLU' or decomposition == 'LU':
        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy
            np_w = scipy.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable("sign_S",
                                     initializer=np_sign_s,
                                     trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 w, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1] * shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 w_inv, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1] * shape[2])

                return z, logdet

    elif decomposition == 'QR':
        with tf.variable_scope(name):
            np_s = np.ones(shape[3], dtype='float32')
            np_u = np.zeros((shape[3], shape[3]), dtype='float32')

            if unit_testing:
                np_s = 1 + 0.02 * np.random.randn(shape[3]).astype('float32')
                np_u = np.random.randn(shape[3], shape[3]).astype('float32')

            np_u = np.triu(np_u, k=1).astype('float32')
            u_mask = np.triu(np.ones(w_shape, dtype='float32'), 1)

            s = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            log_s = tf.log(tf.abs(s))

            r = u * u_mask + tf.diag(s)

            # Householder transformations
            I = tf.eye(shape[3])
            q = I
            for i in range(shape[3]):
                v_np = np.random.randn(shape[3], 1).astype('float32')
                v = tf.get_variable("v_{}".format(i), initializer=v_np)
                vT = tf.transpose(v)
                q_i = I - 2 * tf.matmul(v, vT) / tf.matmul(vT, v)

                q = tf.matmul(q, q_i)

            # Modified Gram–Schmidt process
            # def inner(a, b):
            #     return tf.reduce_sum(a * b)

            # def proj(v, u):
            #     return u * inner(v, u) / inner(u, u)

            # q = []
            # for i in range(shape[3]):
            #     v_np = np.random.randn(shape[3], 1).astype('float32')
            #     v = tf.get_variable("v_{}".format(i), initializer=v_np)
            #     for j in range(i):
            #         p = proj(v, q[j])
            #         v = v - proj(v, q[j])
            #     q.append(v)
            # q = tf.concat(q, axis=1)
            # q = q / tf.norm(q, axis=0, keepdims=True)

            q_inv = tf.transpose(q)
            r_inv = tf.matrix_inverse(r)

            w = tf.matmul(q, r)
            w_inv = tf.matmul(r_inv, q_inv)

            if not reverse:
                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 w, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1] * shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1] + w_shape)
                z = tf.nn.conv2d(z,
                                 w_inv, [1, 1, 1, 1],
                                 'SAME',
                                 data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1] * shape[2])

                return z, logdet

    else:
        raise ValueError('Unknown decomposition: {}'.format(decomposition))
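
In the QR branch, q is a product of Householder reflections (orthogonal, so |det q| = 1) and r is upper triangular with diagonal s, hence log|det w| = sum(log|s|). A numpy check of that identity:

# Numpy check: for W = Q R with Q a product of Householder reflections and R
# upper triangular with diagonal s, log|det W| = sum(log|s|).
import numpy as np

c = 4
s = 1 + 0.02 * np.random.randn(c)
R = np.triu(np.random.randn(c, c), k=1) + np.diag(s)

Q = np.eye(c)
for _ in range(c):
    v = np.random.randn(c, 1)
    Q = Q @ (np.eye(c) - 2 * (v @ v.T) / (v.T @ v))

sign, logdet = np.linalg.slogdet(Q @ R)
print(logdet, np.sum(np.log(np.abs(s))))   # agree up to numerical error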
Example No. 15
def revnet2d_step(name, z, logdet, hps, reverse):
    with tf.variable_scope(name):

        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0

        if not reverse:

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise Exception()

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

        else:

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise Exception()

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
Example No. 16
def invertible_conv2D_emerging_1x1(name,
                                   z,
                                   logdet,
                                   ksize=3,
                                   dilation=1,
                                   reverse=False,
                                   checkpoint_fn=None,
                                   decomposition=None,
                                   unit_testing=False):
    shape = Z.int_shape(z)
    batchsize, height, width, n_channels = shape

    assert (ksize - 1) % 2 == 0

    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        if decomposition is None or decomposition == '':
            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')
            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)
            dlogdet = tf.cast(
                tf.log(abs(tf.matrix_determinant(tf.cast(w, 'float64')))),
                'float32') * shape[1] * shape[2]
            w_inv = tf.matrix_inverse(w)

        elif decomposition == 'PLU' or decomposition == 'LU':
            # LU-decomposed version
            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy
            np_w = scipy.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable("sign_S",
                                     initializer=np_sign_s,
                                     trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, 'float64')
            l = tf.cast(l, 'float64')
            sign_s = tf.cast(sign_s, 'float64')
            log_s = tf.cast(log_s, 'float64')
            u = tf.cast(u, 'float64')

            l_mask = np.tril(np.ones([shape[3], shape[3]], dtype=dtype), -1)
            l = l * l_mask + tf.eye(shape[3], dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            u_inv = tf.matrix_inverse(u)
            l_inv = tf.matrix_inverse(l)
            p_inv = tf.matrix_inverse(p)
            w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            dlogdet = tf.reduce_sum(log_s) * (shape[1] * shape[2])

        elif decomposition == 'QR':
            np_s = np.ones(shape[3], dtype='float32')
            np_u = np.zeros((shape[3], shape[3]), dtype='float32')

            if unit_testing:
                np_s = 1 + 0.02 * np.random.randn(shape[3]).astype('float32')
                np_u = np.random.randn(shape[3], shape[3]).astype('float32')

            np_u = np.triu(np_u, k=1).astype('float32')
            u_mask = np.triu(np.ones([shape[3], shape[3]], dtype='float32'), 1)

            s = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            log_s = tf.log(tf.abs(s))

            r = u * u_mask + tf.diag(s)

            # Householder transformations
            I = tf.eye(shape[3])
            q = I
            for i in range(shape[3]):
                v_np = np.random.randn(shape[3], 1).astype('float32')
                v = tf.get_variable("v_{}".format(i), initializer=v_np)
                vT = tf.transpose(v)
                q_i = I - 2 * tf.matmul(v, vT) / tf.matmul(vT, v)

                q = tf.matmul(q, q_i)

            # Modified Gram–Schmidt process
            # def inner(a, b):
            #     return tf.reduce_sum(a * b)

            # def proj(v, u):
            #     return u * inner(v, u) / inner(u, u)

            # q = []
            # for i in range(shape[3]):
            #     v_np = np.random.randn(shape[3], 1).astype('float32')
            #     v = tf.get_variable("v_{}".format(i), initializer=v_np)
            #     for j in range(i):
            #         p = proj(v, q[j])
            #         v = v - proj(v, q[j])
            #     q.append(v)
            # q = tf.concat(q, axis=1)
            # q = q / tf.norm(q, axis=0, keepdims=True)

            q_inv = tf.transpose(q)
            r_inv = tf.matrix_inverse(r)

            w = tf.matmul(q, r)
            w_inv = tf.matmul(r_inv, q_inv)

            dlogdet = tf.reduce_sum(log_s) * (shape[1] * shape[2])
        else:
            raise ValueError('Unknown decomposition: {}'.format(decomposition))

        mask_np = get_conv_square_ar_mask(ksize, ksize, n_channels, n_channels)

        mask_upsidedown_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()

        mask = tf.constant(mask_np)
        mask_upsidedown = tf.constant(mask_upsidedown_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask_upsidedown

        def log_abs_diagonal(w):
            return tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        def forward(z, logdet):
            w_ = tf.reshape(w, [1, 1] + [shape[3], shape[3]])
            z = tf.nn.conv2d(z, w_, [1, 1, 1, 1], 'SAME', data_format='NHWC')

            logdet += dlogdet

            z = tf.nn.conv2d(z,
                             w1, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME',
                             data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)

            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.nn.conv2d(z,
                             w2, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME',
                             data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)

            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        def forward_fast(z, logdet):
            """
            Convolution with [(k+1) // 2]^2 filters.
            """
            # Smaller versions of w1, w2.
            w1_s = w1[kcent:, kcent:, :, :]
            w2_s = w2[:-kcent, :-kcent, :, :]

            pad = kcent * dilation

            # standard filter shape: [v, u, c_in, c_out]
            # standard fmap shape: [b, h, w, c]

            w_ = tf.transpose(tf.reshape(w, [1, 1] + [shape[3], shape[3]]),
                              (0, 1, 3, 2))
            w_equiv = tf.nn.conv2d(tf.transpose(w1_s, (3, 0, 1, 2)),
                                   w_, [1, 1, 1, 1],
                                   padding='SAME')

            w_equiv = tf.transpose(w_equiv, (1, 2, 3, 0))

            z = tf.pad(z, [[0, 0], [0, pad], [0, pad], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z,
                             w_equiv, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID',
                             data_format='NHWC')

            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.pad(z, [[0, 0], [pad, 0], [pad, 0], [0, 0]], 'CONSTANT')

            z = tf.nn.conv2d(z,
                             w2_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID',
                             data_format='NHWC')

            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        if not reverse:
            x, logdet = forward_fast(z, logdet)
            # x_, _ = forward(z, logdet)

            # x = tf.Print(
            #     x, data=[tf.reduce_mean(tf.square(x - x_))], message='diff')

            return x, logdet

        else:
            logdet -= dlogdet
            logdet -= tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)

            x = tf.py_func(
                Inverse(is_upper=1, dilation=dilation),
                inp=[z, w2, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse2',
            )

            logdet -= tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)

            x = tf.py_func(
                Inverse(is_upper=0, dilation=dilation),
                inp=[x, w1, tf.zeros_like(b)],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse1',
            )

            x.set_shape(z.get_shape())

            z_recon, _ = forward_fast(x, tf.zeros_like(logdet))

            w_inv = tf.reshape(w_inv, [1, 1] + [shape[3], shape[3]])
            x = tf.nn.conv2d(x,
                             w_inv, [1, 1, 1, 1],
                             'SAME',
                             data_format='NHWC')
            logdet -= dlogdet

            # mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))

            # x = tf.Print(
            #     x,
            #     data=[mse],
            #     message='RMSE of inverse',
            # )

            return x, logdet
Example No. 17
def invertible_ar_conv2D(
    name,
    z,
    logdet,
    is_upper,
    ksize=3,
    dilation=1,
    reverse=False,
):
    shape = Z.int_shape(z)
    n_channels = shape[3]
    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_ar_mask(ksize, ksize, n_channels, n_channels)
        if is_upper:
            mask_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()

        mask = tf.constant(mask_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        weight_np = get_conv_weight_np(filter_shape)

        w = tf.get_variable('W', dtype=tf.float32, initializer=weight_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w = mask * w

        log_abs_diagonal = tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        if not reverse:
            z = tf.nn.conv2d(z,
                             w,
                             strides=[1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME',
                             data_format='NHWC') + b
            logdet += tf.reduce_sum(log_abs_diagonal) * (shape[1] * shape[2])

            return z, logdet
        else:
            logdet -= tf.reduce_sum(log_abs_diagonal) * (shape[1] * shape[2])

            x = tf.py_func(
                Inverse(is_upper=is_upper, dilation=dilation),
                inp=[z, w, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse',
            )

            z_recon = tf.nn.conv2d(
                x, w, [1, 1, 1, 1], padding='SAME', data_format='NHWC') + b

            mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))

            x = tf.Print(
                x,
                data=[mse],
                message='RMSE of inverse',
            )

            x.set_shape(z.get_shape())

            return x, logdet
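
The masked convolution is autoregressive, so in raster order its Jacobian is triangular and the determinant depends only on the centre tap's channel diagonal, which is why the logdet term is log_abs_diagonal summed and multiplied by h*w. A single-channel numpy sketch; the causal mask convention used here is an assumption, not the repository's get_conv_ar_mask:

# Single-channel numpy sketch: with a causal 3x3 mask the Jacobian of a
# 'SAME'-padded convolution is triangular in raster order, so its
# log-determinant is h*w*log|centre tap|.
import numpy as np

H, W, k = 5, 5, 3
w = np.random.randn(k, k)
w[1, 2] = 0.   # zero the taps that come after the centre in raster order
w[2, :] = 0.

# Build the Jacobian of the padded correlation explicitly.
J = np.zeros((H * W, H * W))
for i in range(H):
    for j in range(W):
        for di in range(-1, 2):
            for dj in range(-1, 2):
                ii, jj = i + di, j + dj
                if 0 <= ii < H and 0 <= jj < W:
                    J[i * W + j, ii * W + jj] = w[di + 1, dj + 1]

sign, logdet = np.linalg.slogdet(J)
print(logdet, H * W * np.log(np.abs(w[1, 1])))   # agree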
Example No. 18
def invertible_conv2D_emerging(name,
                               z,
                               logdet,
                               ksize=3,
                               dilation=1,
                               reverse=False,
                               checkpoint_fn=None):
    batchsize, height, width, n_channels = Z.int_shape(z)

    assert (ksize - 1) % 2 == 0

    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_square_ar_mask(ksize, ksize, n_channels, n_channels)

        mask_upsidedown_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()

        mask = tf.constant(mask_np)
        mask_upsidedown = tf.constant(mask_upsidedown_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask_upsidedown

        def log_abs_diagonal(w):
            return tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        def forward(z, logdet):
            z = tf.nn.conv2d(z,
                             w1, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME',
                             data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)

            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.nn.conv2d(z,
                             w2, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME',
                             data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)

            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        def forward_fast(z, logdet):
            """
            Convolution with [(k+1) // 2]^2 filters.
            """
            # Smaller versions of w1, w2.
            w1_s = w1[kcent:, kcent:, :, :]
            w2_s = w2[:-kcent, :-kcent, :, :]

            pad = kcent * dilation

            z = tf.pad(z, [[0, 0], [0, pad], [0, pad], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z,
                             w1_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID',
                             data_format='NHWC')

            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.pad(z, [[0, 0], [pad, 0], [pad, 0], [0, 0]], 'CONSTANT')

            z = tf.nn.conv2d(z,
                             w2_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID',
                             data_format='NHWC')

            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        if not reverse:
            x, logdet = forward_fast(z, logdet)

            return x, logdet

        else:
            logdet -= tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)

            x = tf.py_func(
                Inverse(is_upper=1, dilation=dilation),
                inp=[z, w2, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse2',
            )

            logdet -= tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)

            x = tf.py_func(
                Inverse(is_upper=0, dilation=dilation),
                inp=[x, w1, tf.zeros_like(b)],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse1',
            )

            x.set_shape(z.get_shape())

            z_recon, _ = forward_fast(x, tf.zeros_like(logdet))

            # mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))

            # x = tf.Print(
            #     x,
            #     data=[mse],
            #     message='RMSE of inverse',
            # )

            return x, logdet
Example No. 19
def invertible_1x1_conv(name, z, logdet, reverse=False):

    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):

            shape = Z.int_shape(z)
            w_shape = [shape[3], shape[3]]

            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                *w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]

            if not reverse:

                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:

                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1]+w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet

                return z, logdet

    else:

        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l) # noqa
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)  # noqa
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]

            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)  # noqa
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet