コード例 #1
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
def revnet2d_step(name, z, logdet, cfg, reverse):
    """Apply one reversible flow step, forward or inverse.

    The forward pass runs ``ops.scale_bias`` ("actnorm"), then
    ``ops.reverse_features``, then an affine coupling on the two halves of
    the last axis. The inverse pass undoes those three stages in the
    opposite order.

    Args:
        name: Variable scope name for this step.
        z: Tensor indexed here as 4-D; the size of axis 3 must be even.
        logdet: Running log-determinant accumulator.
        cfg: Configuration object forwarded to ``f_``.
        reverse: Falsy for the forward pass, truthy for the inverse pass.

    Returns:
        Tuple ``(z, logdet)`` after the step.
    """
    shape = ops.int_shape(z)
    n_z = shape[3]
    assert n_z % 2 == 0
    with tf.variable_scope(name):
        if not reverse:
            z, logdet = ops.scale_bias("actnorm", z, logdet=logdet)
            # Feature reversal stands where an invertible 1x1 convolution
            # could be used instead.
            z = ops.reverse_features("reverse", z)

            # Affine coupling: z1 passes through unchanged and
            # parameterises the shift/scale applied to z2.
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]
            h = f_("f1", z1, cfg, n_z)
            shift = h[:, :, :, 0::2]
            logs = h[:, :, :, 1::2] / 4.0
            z2 += shift
            z2 *= tf.exp(logs)
            logdet += tf.reduce_sum(logs, axis=[1, 2, 3])
            z = tf.concat([z1, z2], 3)
        else:
            # Inverse coupling: z1 is untouched by the forward pass, so the
            # same shift/scale can be recomputed from it and undone on z2.
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]
            h = f_("f1", z1, cfg, n_z)
            shift = h[:, :, :, 0::2]
            logs = h[:, :, :, 1::2] / 4.0
            z2 *= tf.exp(-1.0 * logs)
            z2 -= shift
            logdet -= tf.reduce_sum(logs, axis=[1, 2, 3])
            z = tf.concat([z1, z2], 3)

            z = ops.reverse_features("reverse", z)
            z, logdet = ops.scale_bias(
                "actnorm", z, logdet=logdet, reverse=True)
    return z, logdet
コード例 #2
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
    def _f_loss(self, x, y):
        """Build the generative loss and the optional classification loss.

        Args:
            x: Input batch; indexed here as 4-D, and the sizes of axes 1-3
                are read as Python ints.
            y: Integer class labels, compared against the int32 argmax of
                the classifier logits.

        Returns:
            Tuple ``(bits_x, bits_y, classification_error, eps)``.
            ``bits_x`` is the negated objective divided by ``log(2)`` times
            the number of non-batch elements of ``x``. ``bits_y`` and
            ``classification_error`` come from the classifier head when
            ``cfg.weight_y > 0 and cfg.ycond``; otherwise they are zeros and
            ones shaped like ``bits_x``. ``eps`` is whatever the encoder
            returned as its third value.
        """
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            y_onehot = tf.cast(tf.one_hot(y, self.cfg.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]

            # No dequantisation noise is added on this path.
            z = x
            # Every non-batch dimension contributes -log(n_bins); the count
            # matches the H*W*C normaliser used for bits_x below.
            objective += - np.log(self.cfg.n_bins) * \
                np.prod(ops.int_shape(z)[1:])

            # Encode
            z = ops.squeeze2d(z, 2)  # > 16x16x12
            z, objective, eps = self.encoder(z, objective)

            # Prior
            self.cfg.top_shape = ops.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, self.cfg)

            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if self.cfg.weight_y > 0 and self.cfg.ycond:
                # Classification loss
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = ops.dense(
                    "classifier", h_y, self.cfg.n_y, has_bn=False)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                # No classifier head: report placeholder values so the
                # return signature is the same in both configurations.
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error, eps
コード例 #3
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
def checkpoint(z, logdet):
    """Register ``z`` and ``logdet`` together in the 'checkpoints' collection.

    ``z`` is flattened to one row per example, ``logdet`` is reshaped to a
    single column, and the two are concatenated into one 2-D tensor that is
    added to the ``'checkpoints'`` graph collection. The returned values are
    sliced back out of that combined tensor, so downstream ops depend on the
    registered tensor rather than on the inputs directly.

    Args:
        z: Tensor whose axis 1-3 sizes are read via ``ops.int_shape``.
        logdet: Tensor reshaped here to ``[-1, 1]``.

    Returns:
        Tuple ``(z, logdet)``, with ``z`` reshaped back to four axes.
    """
    dims = ops.int_shape(z)
    height, width, channels = dims[1], dims[2], dims[3]

    flat_z = tf.reshape(z, [-1, height * width * channels])
    logdet_column = tf.reshape(logdet, [-1, 1])
    packed = tf.concat([flat_z, logdet_column], axis=1)
    tf.add_to_collection('checkpoints', packed)

    # The last column holds logdet; everything before it is the flattened z.
    restored_logdet = packed[:, -1]
    restored_z = tf.reshape(packed[:, :-1], [-1, height, width, channels])
    return restored_z, restored_logdet
コード例 #4
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
def split2d(name, z, cfg, objective=0.):
    """Split ``z`` along axis 3, score the second half, keep the first.

    The first half conditions a prior built by ``split2d_prior``. The log
    probability of the second half under that prior is added to
    ``objective``, and the first half is passed through ``ops.squeeze2d``.

    Args:
        name: Variable scope name.
        z: Tensor indexed here as 4-D.
        cfg: Configuration object forwarded to ``split2d_prior``.
        objective: Running objective to accumulate into.

    Returns:
        Tuple ``(z1, objective, eps)`` where ``z1`` is the squeezed first
        half and ``eps`` is ``get_eps`` of the prior for the second half.
    """
    with tf.variable_scope(name):
        half = ops.int_shape(z)[3] // 2
        kept = z[:, :, :, :half]
        factored_out = z[:, :, :, half:]

        prior_dist = split2d_prior(kept, cfg)
        objective += prior_dist.logp(factored_out)
        kept = ops.squeeze2d(kept)
        eps = prior_dist.get_eps(factored_out)
    return kept, objective, eps
コード例 #5
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
    def f_encode(self, x, y):
        """Encode a batch and return the encoder's ``eps`` output.

        The objective is accumulated the same way as in the loss graph
        (dequantisation term, encoder, prior) but only ``eps`` is returned.

        Args:
            x: Input batch; indexed here as 4-D.
            y: Integer class labels, one-hot encoded with ``cfg.n_y`` classes.

        Returns:
            The third value returned by ``self.encoder``.
        """
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            y_onehot = tf.cast(tf.one_hot(y, self.cfg.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
            z = x + tf.random_uniform(tf.shape(x), 0, 1. / self.cfg.n_bins)
            # Every non-batch dimension contributes -log(n_bins) for the
            # uniform noise of width 1/n_bins added above.
            objective += - np.log(self.cfg.n_bins) * \
                np.prod(ops.int_shape(z)[1:])

            # Encode
            z = ops.squeeze2d(z, 2)  # > 16x16x12
            z, objective, eps = self.encoder(z, objective)

            # Prior
            self.cfg.top_shape = ops.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, self.cfg)
            objective += logp(z)
        return eps
コード例 #6
ファイル: models.py プロジェクト: omidsakhi/tpu_glow
def invertible_1x1_conv(name, z, logdet, reverse=False):
    """Apply a learned invertible 1x1 convolution, or its inverse.

    Two implementations are present, chosen by the hard-coded ``if True``
    switch below: a direct one that learns the full ``C x C`` matrix, and an
    LU-decomposed one that learns the factors of ``W = P L U``.

    Args:
        name: Variable scope name.
        z: Tensor convolved as NHWC; ``C`` is taken from axis 3.
        logdet: Running log-determinant accumulator.
        reverse: Falsy to apply ``W`` and add to ``logdet``; truthy to apply
            the inverse of ``W`` and subtract from ``logdet``.

    Returns:
        Tuple ``(z, logdet)`` after the convolution.
    """
    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):

            shape = ops.int_shape(z)
            C = shape[3]

            w = tf.get_variable("w", shape=(
                C, C), dtype=tf.float32, initializer=tf.initializers.orthogonal())

            # The determinant is taken in float64, then scaled by the
            # number of spatial positions the same matrix is applied to.
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]

            if not reverse:
                w = tf.reshape(w, [1, 1, C, C])
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:
                w = tf.matrix_inverse(w)
                w = tf.reshape(w, [1, 1, C, C])
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet

                return z, logdet
    else:
        # LU-decomposed version
        shape = ops.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix (the Q factor of a random matrix).
            # Import the submodule explicitly: a bare "import scipy" does
            # not guarantee that scipy.linalg is loaded.
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)  # pylint: disable=E1101
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            # P and sign(S) are not trainable; L, U and log|S| are.
            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]

            # Masks keep L unit lower-triangular and U strictly upper
            # triangular, with the diagonal supplied by sign_s * exp(log_s).
            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                # Invert factor by factor: W^-1 = U^-1 L^-1 P^-1.
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet