示例#1
0
def deconv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Weight-normalized transposed 2-D convolution with optional BN and nonlinearity.

    Weight normalization (Salimans & Kingma, 2016): the kernel is
    ``g * V / ||V||`` with the L2 norm taken over axes 0, 1 and 3 of ``V``.

    Args:
        inputs: 4-D input tensor. Its static shape (``int_shape``) is used to
            build the output shape, so it must be known at graph-build time.
        num_filters: number of output channels.
        kernel_size: int or ``[height, width]`` (list or tuple).
        strides: int or ``[stride_h, stride_w]`` (list or tuple).
        padding: 'SAME' (spatial output ``input * stride``) or anything else,
            treated as 'VALID' (spatial output ``input * stride + kernel - 1``).
        nonlinearity: optional callable applied after batch normalization.
        bn: apply ``tf.layers.batch_normalization`` if True.
        kernel_initializer, kernel_regularizer: accepted for signature
            compatibility but NOT used here; ``V`` always uses
            ``random_normal_initializer(0, 0.05)`` and no regularizer.
        is_training: forwarded to batch normalization.
        counters: name-counter dict passed to ``get_name``.
            NOTE(review): the mutable default is shared between calls;
            presumably ``get_name`` relies on that to produce unique scope
            names — confirm before changing it.

    Returns:
        The output tensor.
    """
    # Normalize to lists. Without this, the default ``strides=1`` crashed on
    # ``stride[0]``, and ints/tuples broke the ``list + list`` shape arithmetic.
    if isinstance(kernel_size, int):
        filter_size = [kernel_size, kernel_size]
    else:
        filter_size = list(kernel_size)
    if isinstance(strides, int):
        stride = [strides, strides]
    else:
        stride = list(strides)
    pad = padding
    x = inputs
    xs = int_shape(x)
    name = get_name('deconv2d', counters)
    if pad == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0], xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1, num_filters]
    with tf.variable_scope(name):
        # conv2d_transpose kernels are laid out [h, w, out_channels, in_channels].
        V = tf.get_variable('V', shape=filter_size + [num_filters, int(x.get_shape()[-1])], dtype=tf.float32,
                            initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        g = tf.get_variable('g', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(1.), trainable=True)
        b = tf.get_variable('b', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.), trainable=True)

        # Weight normalization: per-output-channel scale g times unit-norm V.
        W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])

        x = tf.nn.conv2d_transpose(x, W, target_shape, [1] + stride + [1], padding=pad)
        x = tf.nn.bias_add(x, b)

        outputs = x
        if bn:
            outputs = tf.layers.batch_normalization(outputs, training=is_training)
        if nonlinearity is not None:
            outputs = nonlinearity(outputs)
        print("    + deconv2d", int_shape(inputs), int_shape(outputs), nonlinearity, bn)
        return outputs
示例#2
0
def deconv2d(inputs,
             num_filters,
             kernel_size,
             strides=1,
             padding='SAME',
             nonlinearity=None,
             bn=True,
             kernel_initializer=None,
             kernel_regularizer=None,
             is_training=False,
             counters={}):
    """Transposed 2-D convolution via ``deconv2d_openai``, then optional BN and nonlinearity.

    Args:
        inputs: 4-D input tensor.
        num_filters: number of output channels.
        kernel_size: int or ``[height, width]`` (list or tuple).
        strides: int or ``[stride_h, stride_w]`` (list or tuple).
        padding: padding mode forwarded to ``deconv2d_openai`` as ``pad``.
        nonlinearity: optional callable applied after batch normalization.
        bn: apply ``tf.layers.batch_normalization`` if True.
        kernel_initializer, kernel_regularizer: forwarded to
            ``deconv2d_openai``. NOTE(review): whether the regularizer is
            actually applied depends on ``deconv2d_openai``'s signature.
        is_training: forwarded to batch normalization.
        counters: name-counter dict forwarded to ``deconv2d_openai``.

    Returns:
        The output tensor.
    """
    # ``deconv2d_openai`` builds shapes with ``list + list``, so ints and
    # tuples must both be turned into lists.
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    else:
        kernel_size = list(kernel_size)
    if isinstance(strides, int):
        strides = [strides, strides]
    else:
        strides = list(strides)
    outputs = deconv2d_openai(inputs,
                              num_filters,
                              filter_size=kernel_size,
                              stride=strides,
                              pad=padding,
                              counters=counters,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=kernel_regularizer)
    if bn:
        outputs = tf.layers.batch_normalization(outputs, training=is_training)
    if nonlinearity is not None:
        outputs = nonlinearity(outputs)
    print("    + deconv2d", int_shape(inputs), int_shape(outputs),
          nonlinearity, bn)
    return outputs
 def __model(self, x, is_training):
     """Build the encoder/decoder graph.

     Chooses the "medium" encoder/decoder pair for 64 or 32 pixel inputs
     from ``int_shape(x)[1]``. It encodes ``x`` into
     ``(self.z_mu, self.z_log_sigma_sq)`` and decodes ``self.z`` into
     ``self.x_hat``. In 'train' mode ``self.z`` is drawn with
     ``gaussian_sampler``; in 'test' mode it is a placeholder to be fed.

     Raises:
         ValueError: if ``int_shape(x)[1]`` is not 32 or 64, or if
             ``self.use_mode`` is neither 'train' nor 'test'.
     """
     print("******   Building Graph   ******")
     self.x = x
     self.is_training = is_training
     image_size = int_shape(x)[1]
     if image_size == 64:
         encoder = conv_encoder_64_medium
         decoder = conv_decoder_64_medium
     elif image_size == 32:
         encoder = conv_encoder_32_medium
         decoder = conv_decoder_32_medium
     else:
         # Previously fell through and crashed later with UnboundLocalError.
         raise ValueError("unsupported input size: {0}".format(image_size))
     with arg_scope([encoder, decoder],
                    nonlinearity=self.nonlinearity,
                    bn=self.bn,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    is_training=self.is_training,
                    counters=self.counters):
         self.z_mu, self.z_log_sigma_sq = encoder(x, self.z_dim)
         # z_log_sigma_sq is a log-variance, so sigma = exp(log_var / 2).
         sigma = tf.exp(self.z_log_sigma_sq / 2.)
         if self.use_mode == 'train':
             self.z = gaussian_sampler(self.z_mu, sigma)
         elif self.use_mode == 'test':
             self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
         else:
             # Otherwise self.z would be unset when the decoder runs.
             raise ValueError("unknown use_mode: {0}".format(self.use_mode))
         print("use mode:{0}".format(self.use_mode))
         self.x_hat = decoder(self.z)
示例#4
0
 def __model(self,
             x,
             x_bar,
             is_training,
             dropout_p,
             masks,
             input_masks,
             network_size="medium"):
     """Build the encoder / decoder / conditional PixelCNN graph.

     The encoder maps ``x`` to ``(self.z_mu, self.z_log_sigma_sq)``. The
     decoder turns ``self.z`` into feature maps (``output_features=True``).
     These condition ``cond_pixel_cnn`` on ``x_bar``, whose mixture
     parameters are then sampled into ``self.x_hat``.

     Args:
         x: input images; ``int_shape(x)[1]`` selects the 64 or 32 pixel nets.
         x_bar: input fed to ``cond_pixel_cnn``.
         is_training, dropout_p, masks, input_masks: stored on ``self``.
         network_size: 'medium', 'large' or 'large1'. Only used for 32 pixel
             inputs; 64 pixel inputs always use the medium pair.

     Raises:
         ValueError: if the input size is not 32 or 64, or if
             ``self.use_mode`` is neither 'train' nor 'test'.
         Exception: if ``network_size`` is unknown for 32 pixel inputs.
     """
     print("******   Building Graph   ******")
     self.x = x
     self.x_bar = x_bar
     self.is_training = is_training
     self.dropout_p = dropout_p
     self.masks = masks
     self.input_masks = input_masks
     image_size = int_shape(x)[1]
     if image_size == 64:
         conv_encoder = conv_encoder_64_medium
         conv_decoder = conv_decoder_64_medium
     elif image_size == 32:
         if network_size == 'medium':
             conv_encoder = conv_encoder_32_medium
             conv_decoder = conv_decoder_32_medium
         elif network_size == 'large':
             conv_encoder = conv_encoder_32_large
             conv_decoder = conv_decoder_32_large
         elif network_size == 'large1':
             conv_encoder = conv_encoder_32_large1
             conv_decoder = conv_decoder_32_large1
         else:
             raise Exception("unknown network type")
     else:
         # Previously fell through and crashed later with UnboundLocalError.
         raise ValueError("unsupported input size: {0}".format(image_size))
     with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                    nonlinearity=self.nonlinearity,
                    bn=self.bn,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    is_training=self.is_training,
                    counters=self.counters):
         self.z_mu, self.z_log_sigma_sq = conv_encoder(self.x, self.z_dim)
         # z_log_sigma_sq is a log-variance, so sigma = exp(log_var / 2).
         sigma = tf.exp(self.z_log_sigma_sq / 2.)
         if self.use_mode == 'train':
             self.z = gaussian_sampler(self.z_mu, sigma)
         elif self.use_mode == 'test':
             self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
         else:
             # Otherwise self.z would be unset when the decoder runs.
             raise ValueError("unknown use_mode: {0}".format(self.use_mode))
         print("use mode:{0}".format(self.use_mode))
         self.decoded_features = conv_decoder(self.z, output_features=True)
         # bn=False overrides the arg_scope: cond_pixel_cnn asserts that batch
         # normalization is disabled.
         self.mix_logistic_params = cond_pixel_cnn(
             self.x_bar,
             sh=self.decoded_features,
             bn=False,
             dropout_p=self.dropout_p,
             nr_resnet=self.nr_resnet,
             nr_filters=self.nr_filters,
             nr_logistic_mix=self.nr_logistic_mix)
         self.x_hat = mix_logistic_sampler(
             self.mix_logistic_params,
             nr_logistic_mix=self.nr_logistic_mix,
             sample_range=self.sample_range,
             counters=self.counters)
示例#5
0
def conv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False):
    """2-D convolution (``tf.layers.conv2d``) with optional nonlinearity and batch norm.

    The nonlinearity, if given, is applied BEFORE batch normalization.

    Returns:
        The output tensor.
    """
    h = tf.layers.conv2d(
        inputs,
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )
    h = h if nonlinearity is None else nonlinearity(h)
    if bn:
        h = tf.layers.batch_normalization(h, training=is_training)
    print("    + conv2d", int_shape(inputs), int_shape(h), nonlinearity, bn)
    return h
示例#6
0
def gaussian_sampler(loc, scale, counters={}):
    """Reparameterized Gaussian sample ``loc + scale * eps`` with ``eps ~ N(0, 1)``.

    ``eps`` is drawn with sample shape ``int_shape(loc)`` (the static shape of
    ``loc``) inside a fresh variable scope named by ``get_name``.
    """
    scope = get_name("gaussian_sampler", counters)
    print("construct", scope, "...")
    with tf.variable_scope(scope):
        standard_normal = tf.distributions.Normal(loc=0., scale=1.)
        eps = standard_normal.sample(sample_shape=int_shape(loc), seed=None)
        sample = loc + tf.multiply(eps, scale)
        print("    + gaussian_sampler", int_shape(sample))
    return sample
示例#7
0
def bernoulli_loss(x, l, sum_all=True):
    """Sigmoid cross-entropy between targets ``x`` and logits ``l``.

    The element-wise cross-entropy is averaged over axis 3, then summed.

    Args:
        x: target tensor (labels for ``sigmoid_cross_entropy_with_logits``).
            Assumed rank 4, e.g. NHWC — TODO confirm with callers.
        l: logits with the same shape as ``x``.
        sum_all: if True return a scalar summed over every remaining axis;
            otherwise sum over axes 1 and 2 only (one value per batch entry
            for rank-4 inputs).

    Returns:
        The cross-entropy loss (a positive quantity, not a log-likelihood).
    """
    per_pixel = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=l), 3)
    if sum_all:
        return tf.reduce_sum(per_pixel)
    return tf.reduce_sum(per_pixel, [1, 2])
示例#8
0
def dense(inputs, num_outputs, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False):
    """Fully connected layer (``tf.layers.dense``) on a rank-2 input.

    The nonlinearity, if given, is applied BEFORE batch normalization.

    Returns:
        Tensor with ``num_outputs`` features.
    """
    rank = len(int_shape(inputs))
    assert rank == 2, "inputs should be flattened first"
    h = tf.layers.dense(inputs,
                        num_outputs,
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=kernel_regularizer)
    if nonlinearity is not None:
        h = nonlinearity(h)
    if bn:
        h = tf.layers.batch_normalization(h, training=is_training)
    print("    + dense", int_shape(inputs), int_shape(h), nonlinearity, bn)
    return h
def bernoulli_loss(x, l, masks=None, output_mean=True):
    """Optionally masked sigmoid cross-entropy between targets ``x`` and logits ``l``.

    The element-wise cross-entropy is averaged over axis 3. Masked positions
    are then zeroed and the result is summed.

    Args:
        x: target tensor; assumed rank 4 (e.g. NHWC) — TODO confirm.
        l: logits with the same shape as ``x``.
        masks: optional tensor with the same static shape as the
            axis-3-averaged loss. Positions where the mask is 1 are zeroed
            via the ``(1 - masks)`` factor.
        output_mean: despite its name, True returns the SUM over all
            elements (a scalar). False returns sums over axes 1 and 2. The
            name is kept for backward compatibility.

    Returns:
        The cross-entropy loss.

    Raises:
        ValueError: if ``masks`` does not match the per-pixel loss shape.
    """
    per_pixel = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=l), 3)
    if masks is not None:
        # Explicit check instead of ``assert`` so it is not stripped under
        # ``python -O``; ``not (a == b)`` keeps TensorShape equality semantics.
        if not (per_pixel.shape == masks.shape):
            raise ValueError(
                "shape of masks does not match the per-pixel loss")
        per_pixel *= (1 - masks)
    if output_mean:
        return tf.reduce_sum(per_pixel)
    return tf.reduce_sum(per_pixel, [1, 2])
def discretized_mix_logistic_loss(x,l,sum_all=True, masks=None):
    """Negative log-likelihood of ``x`` under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1, 1] interval. Bins have
    half-width 1/255 (see ``plus_in`` / ``min_in`` below), i.e. 256 levels.

    Args:
        x: true image (the labels to regress to), e.g. (B, 32, 32, 3). The
            mean adjustment below indexes sub-pixels 0, 1 and 2, so three
            channels are expected.
        l: predicted mixture parameters, e.g. (B, 32, 32, 100). The last axis
            holds ``nr_mix`` mixture logits followed by ``3 * nr_mix`` values
            per channel (means, log-scales, coefficients); hence
            ``nr_mix = l.shape[-1] / 10`` for 3-channel data.
        sum_all: if True return a scalar; otherwise sum over axes 1 and 2,
            giving one value per batch element.
        masks: optional tensor with the same static shape as the per-pixel
            log-likelihood ``lse``. Positions where the mask is 1 are
            excluded through the ``(1 - masks)`` factor.

    Returns:
        The negated sum of (masked) per-pixel log-likelihoods.
    """
    xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
    nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
    logit_probs = l[:,:,:,:nr_mix]
    l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
    means = l[:,:,:,:,:nr_mix]
    # log-scales are clamped from below at -7 to keep inv_stdv bounded
    log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
    coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])

    x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
    # channel 1's mean is shifted linearly by sub-pixel 0; channel 2's by sub-pixels 0 and 1
    m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
    m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
    means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
    centered_x = x - means
    inv_stdv = tf.exp(-log_scales)

    plus_in = inv_stdv * (centered_x + 1./255.)
    cdf_plus = tf.nn.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1./255.)
    cdf_min = tf.nn.sigmoid(min_in)
    log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
    log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
    cdf_delta = cdf_plus - cdf_min # probability for all other cases
    mid_in = inv_stdv * centered_x
    log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

    # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
    # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.select() (now tf.where) by multiplying with zero instead of selecting: this requires us to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
    # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value
    # (log_pdf_mid - log(127.5) is log(pdf * bin width), since the bin width 2/255 equals 1/127.5)
    log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

    # sum over the channel axis (axis 3 of the 5-D tensor), then add the mixture log-weights;
    # log_prob_from_logits is defined elsewhere (presumably a log-softmax — confirm)
    log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)

    # log_sum_exp is defined elsewhere; presumably it reduces over the last (mixture) axis
    lse = log_sum_exp(log_probs)

    if masks is not None:
        # NOTE(review): this assert is skipped under python -O
        assert lse.shape==masks.shape, "shape of masks does not match the log_sum_exp outputs"
        lse *= (1 - masks)
    if sum_all:
        return -tf.reduce_sum(lse)
    else:
        return -tf.reduce_sum(lse,[1,2])
def conditioning_network_28(x,
                            masks,
                            nr_filters,
                            is_training=True,
                            nonlinearity=None,
                            bn=True,
                            kernel_initializer=None,
                            kernel_regularizer=None,
                            counters={}):
    """Six-layer convolutional network over a masked input.

    The input is multiplied by the 3-channel broadcast mask. The 1-channel
    mask and a channel of ones are then appended. Five 4x4, stride-1, 'SAME'
    convolutions follow, then a 1x1 projection with no nonlinearity and no
    batch norm. Every layer has ``nr_filters`` output channels.
    """
    scope_name = get_name("conditioning_network_28", counters)
    masked = x * broadcast_masks_tf(masks, num_channels=3)
    masked = tf.concat([masked, broadcast_masks_tf(masks, num_channels=1)],
                       axis=-1)
    ones = tf.ones(int_shape(masked)[:-1] + [1])
    net = tf.concat([masked, ones], 3)
    with tf.variable_scope(scope_name):
        with arg_scope([conv2d, residual_block, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            # Input layer plus four hidden layers, all identical 4x4 convs.
            for _ in range(5):
                net = conv2d(net, nr_filters, 4, 1, "SAME")
            return conv2d(net, nr_filters, 1, 1, "SAME",
                          nonlinearity=None, bn=False)
示例#12
0
def conditional_decoder(x,
                        z,
                        nonlinearity=None,
                        bn=True,
                        kernel_initializer=None,
                        kernel_regularizer=None,
                        is_training=False,
                        counters={}):
    """Gated MLP that maps ``x`` (conditioned on ``z``) to one scalar per row of ``x``.

    ``x`` is tiled ``int_shape(z)[1]`` times along axis 1. ``z`` is tiled
    ``tf.shape(x)[0]`` times along axis 0. Five gated units
    ``tanh(a) * sigmoid(a)`` follow, each with
    ``a = dense(h) + dense(z_tiled)`` and 256 hidden units. A final linear
    ``dense`` to 1 output is reshaped to ``(tf.shape(x)[0],)``.
    NOTE(review): for the addition to line up, ``z`` presumably has a single
    row — confirm with callers.
    """
    scope = get_name("conditional_decoder", counters)
    print("construct", scope, "...")
    with tf.variable_scope(scope):
        with arg_scope([dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training):
            hidden_units = 256
            n_rows = tf.shape(x)[0]
            x_tiled = tf.tile(x, tf.stack([1, int_shape(z)[1]]))
            z_tiled = tf.tile(z, tf.stack([n_rows, 1]))

            def gated_unit(h):
                # Project h and z separately, sum, then apply a tanh*sigmoid gate.
                pre = dense(h, hidden_units, nonlinearity=None) + dense(
                    z_tiled, hidden_units, nonlinearity=None)
                return tf.nn.tanh(pre) * tf.sigmoid(pre)

            h = gated_unit(x_tiled)
            for _ in range(4):
                h = gated_unit(h)
            scores = dense(h, 1, nonlinearity=None, bn=False)
            return tf.reshape(scores, shape=(n_rows, ))
def estimate_mi_tc_dwkld(z, z_mu, z_log_sigma_sq, N=2e5):
    """Minibatch estimates of mutual information, total correlation and
    dimension-wise KL.

    The three terms are computed together because they share the pairwise
    log-density tensor ``log q(z_j | x_i)``. Off-diagonal pairs are weighted
    by ``(N - 1) / (batch_size - 1)``; diagonal pairs get weight 1.

    Args:
        z: sampled latents, shape (batch_size, z_dim).
        z_mu: posterior means, shape (batch_size, z_dim).
        z_log_sigma_sq: posterior log-variances, shape (batch_size, z_dim).
        N: dataset size used for the importance weights and normalization.

    Returns:
        Tuple ``(mi, tc, dwkld)`` of scalar tensors.

    Raises:
        ValueError: if the static batch size is unknown or smaller than 2
            (the weight ``(N - 1) / (batch_size - 1)`` is undefined).
    """
    batch_size, z_dim = int_shape(z_mu)
    if batch_size is None or batch_size < 2:
        raise ValueError(
            "estimate_mi_tc_dwkld needs a static batch size >= 2, got {0}".format(batch_size))
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))

    # Pairwise tensors of shape (B, B, z_dim): entry [i, j, d] pairs the
    # posterior of example i with the sample of example j.
    z_b = tf.stack([z] * batch_size, axis=0)
    z_mu_b = tf.stack([z_mu] * batch_size, axis=1)
    z_sigma_b = tf.stack([z_sigma] * batch_size, axis=1)
    z_norm = (z_b - z_mu_b) / z_sigma_b

    dist = tf.distributions.Normal(loc=0., scale=1.)
    # log N(z; mu, sigma) = log N((z - mu) / sigma; 0, 1) - log(sigma).
    # The -log(sigma) change-of-variables term was previously missing.
    # NOTE(review): if estimate_log_probs mirrors the old code, it likely
    # needs the same correction.
    log_probs = dist.log_prob(z_norm) - tf.log(z_sigma_b)

    ratio = np.log(float(N - 1) / (batch_size - 1)) * np.ones(
        (batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_b = np.stack([ratio] * z_dim, axis=-1)

    # log q(z_j): joint density over dimensions, log-sum-exp over posteriors i.
    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    # sum_d log q(z_j,d): per-dimension log-sum-exp, then summed over d.
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio_b, axis=0), axis=-1))
    lse_sum -= tf.log(float(N))
    sum_lse -= tf.log(float(N)) * float(z_dim)

    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))

    mi = -lse_sum - cond_entropy
    tc = lse_sum - sum_lse
    dwkld = sum_lse + (kld + cond_entropy)

    return mi, tc, dwkld
示例#14
0
def gated_resnet(x,
                 a=None,
                 gh=None,
                 sh=None,
                 nonlinearity=tf.nn.elu,
                 conv=conv2d,
                 dropout_p=0.0,
                 counters={},
                 **kwargs):
    """Gated residual block returning ``x + value * sigmoid(gate)``.

    Args:
        x: input tensor; its last dimension sets the number of filters.
        a: optional auxiliary input, added through a ``nin`` shortcut.
        gh: accepted but currently unused.
        sh: optional conditioning features added to the gate pre-activation
            via ``nin``.
        nonlinearity: activation applied before each convolution.
        conv: convolution function used for both convolutions.
        dropout_p: dropout probability between the two convolutions.
        counters: name-counter dict; also forwarded to ``conv`` via arg_scope.
        **kwargs: extra arguments forwarded to ``conv`` via arg_scope.
    """
    name = get_name("gated_resnet", counters)
    print("construct", name, "...")
    num_filters = int_shape(x)[-1]
    kwargs["counters"] = counters
    with arg_scope([conv], **kwargs):
        hidden = conv(nonlinearity(x), num_filters)
        if a is not None:
            # Short-cut connection from the auxiliary input.
            hidden += nin(nonlinearity(a), num_filters)
        hidden = tf.nn.dropout(nonlinearity(hidden), keep_prob=1. - dropout_p)
        gate_input = conv(hidden, num_filters * 2)
        if sh is not None:
            # Conditional generation: project the conditioning features.
            gate_input += nin(sh, 2 * num_filters, nonlinearity=nonlinearity)
        # TODO: the gh path is not implemented yet.
        value, gate = tf.split(gate_input, 2, 3)
        return x + value * tf.nn.sigmoid(gate)
def estimate_mi(z, z_mu, z_log_sigma_sq, N=200000):
    """Minibatch estimate of the mutual information between data and latents.

    Computed as ``-(lse_sum - log N) - mean(conditional entropy)``, where
    ``lse_sum`` is the first value returned by ``estimate_log_probs``.
    """
    # The unpacking only enforces a rank-2 z_mu; the values are not needed.
    _batch_size, _z_dim = int_shape(z_mu)
    log_qz, _ = estimate_log_probs(z, z_mu, z_log_sigma_sq, N=N)
    log_qz = log_qz - tf.log(float(N))
    entropy = tf.reduce_mean(compute_gaussian_entropy(z_mu, z_log_sigma_sq))
    return -log_qz - entropy
示例#16
0
def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Masked PixelCNN-style feature extractor built from up/left-shifted convolutions.

    The input is multiplied by the 3-channel broadcast mask. The 1-channel
    mask and a channel of ones (to tell image from padding) are appended. Two
    streams ("up" and "up-left") are then refined by ``nr_resnet`` gated
    residual blocks each.

    Args:
        x: input image tensor.
        masks: mask tensor passed to ``broadcast_masks_tf``.
        context: optional conditioning features, passed to every
            ``gated_resnet`` as ``sh``.
        nr_logistic_mix: only printed; the output has ``nr_filters`` channels.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel count of every layer and of the output.
        dropout_p: dropout probability inside the gated residual blocks.
        nonlinearity: activation used by the gated residual blocks. ``None``
            falls back to ``tf.nn.elu``.
        bn: must be False (asserted). The default True fails the assertion,
            so callers must pass ``bn=False``.
        kernel_initializer, kernel_regularizer, is_training, counters:
            forwarded to the layers through arg_scope.

    Returns:
        ``nin`` of ``elu`` of the final up-left stream (``nr_filters`` channels).
    """
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    if nonlinearity is None:
        # gated_resnet calls nonlinearity(x) unconditionally, so None raised a
        # TypeError whenever nr_resnet >= 1. Fall back to ELU, the default of
        # cond_pixel_cnn and the activation used on the output below.
        nonlinearity = tf.nn.elu
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
                    up_shifted_deconv2d, up_left_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                xs = int_shape(x)
                # Add a channel of ones to distinguish image from padding later on.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                # Stream for pixels above.
                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]
                # Stream for pixels above and to the left.
                ul_list = [
                    up_shift(up_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    left_shift(up_left_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 1]))
                ]

                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))

                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                return x_out
    def __loss(self, reg):
        """Build ``self.loss = self.loss_ae + self.loss_reg``.

        ``self.loss_ae`` is ``mix_logistic_loss`` of ``self.x`` under
        ``self.mix_logistic_params``, masked by ``self.masks``. ``reg``
        selects the regularizer stored in ``self.loss_reg``:

        * ``None``: 0.
        * ``'kld'``: ``beta * max(lam, kld)``.
        * ``'mmd'``: ``beta * max(lam, mmd)``, with MMD against 256 samples
          from ``tf.random_normal``.
        * ``'tc'``: ``mi + beta * tc + dwkld``.
        * ``'info-tc'``: ``beta * tc + dwkld``.
        * ``'tc-dwmmd'``: ``beta * tc + gamma * dwmmd``.
        * ``'mmd-tc'``: ``(mmd + beta * mmdtc) * 1e5``.

        Afterwards ``self.mi`` is always recomputed with ``estimate_mi``
        (N=200000), overwriting any value set by the branches above.

        Raises:
            ValueError: if ``reg`` is not one of the values listed above.
        """
        print("******   Compute Loss   ******")
        self.mmd, self.kld, self.mi, self.tc, self.dwkld = [
            None for i in range(5)
        ]
        self.gamma, self.dwmmd = 1e3, None  ## hard coded, experimental
        self.mmdtc = None
        self.loss_ae = mix_logistic_loss(self.x,
                                         self.mix_logistic_params,
                                         masks=self.masks)
        if reg is None:
            self.loss_reg = 0
        elif reg == 'kld':
            self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
        elif reg == 'mmd':
            # Compares against a fixed 256-sample prior batch (hard coded).
            self.mmd = estimate_mmd(
                tf.random_normal(tf.stack([256, self.z_dim])), self.z)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
        elif reg == 'tc':
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
        elif reg == 'info-tc':
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            self.loss_reg = self.beta * self.tc + self.dwkld
        elif reg == 'tc-dwmmd':
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            self.dwmmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                      self.z,
                                      is_dimention_wise=True)
            self.loss_reg = self.beta * self.tc + self.dwmmd * self.gamma
        elif reg == 'mmd-tc':
            self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                    self.z)
            self.mmdtc = estimate_mmdtc(self.z, self.random_indices)
            self.loss_reg = (self.mmd + self.beta * self.mmdtc) * 1e5
        else:
            # Previously left self.loss_reg unset and failed with an
            # AttributeError on the final sum below.
            raise ValueError("unknown reg: {0}".format(reg))

        self.mi = estimate_mi(self.z, self.z_mu, self.z_log_sigma_sq, N=200000)

        print("reg:{0}, beta:{1}, lam:{2}".format(self.reg, self.beta,
                                                  self.lam))
        self.loss = self.loss_ae + self.loss_reg
示例#18
0
def cond_pixel_cnn(x,
                   gh=None,
                   sh=None,
                   nonlinearity=tf.nn.elu,
                   nr_resnet=5,
                   nr_filters=100,
                   nr_logistic_mix=10,
                   bn=False,
                   dropout_p=0.0,
                   kernel_initializer=None,
                   kernel_regularizer=None,
                   is_training=False,
                   counters={}):
    """Conditional PixelCNN producing mixture-of-logistics parameters.

    A channel of ones is appended to ``x`` (to tell image from padding). Two
    streams ("down" and "down-right") are built with shifted convolutions and
    refined by ``nr_resnet`` gated residual blocks each.

    Args:
        x: input image tensor.
        gh, sh: conditioning passed to every ``gated_resnet``.
        nonlinearity: activation inside the gated residual blocks. ``None``
            (e.g. injected through an arg_scope) falls back to ``tf.nn.elu``.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel count of the hidden layers.
        nr_logistic_mix: number of mixture components; the output has
            ``10 * nr_logistic_mix`` channels.
        bn: must be False (asserted).
        dropout_p: dropout probability inside the gated residual blocks.
        kernel_initializer, kernel_regularizer, is_training: forwarded to the
            shifted convolutions and residual blocks through arg_scope.
        counters: name-counter dict, forwarded to ``gated_resnet``.

    Returns:
        ``nin`` of ``elu`` of the final down-right stream.
    """
    # NOTE: the scope prefix is "conv_pixel_cnn" (not "cond_"); it is kept
    # unchanged because renaming it would rename the created variables.
    name = get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    if nonlinearity is None:
        # gated_resnet calls nonlinearity(x) unconditionally; None would raise
        # a TypeError whenever nr_resnet >= 1.
        nonlinearity = tf.nn.elu
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=gh,
                       sh=sh,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p,
                       counters=counters):
            with arg_scope(
                [gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                xs = int_shape(x)
                # Add a channel of ones to distinguish image from padding later on.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                # Stream for pixels above.
                u_list = [
                    down_shift(
                        down_shifted_conv2d(x_pad,
                                            num_filters=nr_filters,
                                            filter_size=[2, 3]))
                ]
                # Stream for pixels above and to the left.
                ul_list = [
                    down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 1]))
                ]
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=down_right_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
                print("    * receptive_field", receptive_field)
                return x_out
def estimate_mmdtc(y, indices):
    """MMD between the first half of ``y`` and a per-dimension re-indexed second half.

    ``y`` (rank 2) is split at ``batch_size // 2``. For each latent
    dimension ``d``, the rows of the second half are gathered with
    ``indices[:, d]``; presumably a per-dimension permutation that breaks
    dependencies between dimensions — confirm with callers. The result is
    compared with the first half via ``estimate_mmd``.
    """
    batch_size, z_dim = int_shape(y)
    half = batch_size // 2
    first_half, second_half = y[:half], y[half:]
    columns = tf.unstack(second_half, axis=1)
    shuffled = tf.stack(
        [tf.gather(columns[d], indices[:, d]) for d in range(z_dim)], axis=1)
    return estimate_mmd(first_half, shuffled)
def estimate_dwkld(z, z_mu, z_log_sigma_sq, N=200000):
    """Minibatch estimate of the dimension-wise KL term.

    Returns ``(sum_lse - z_dim * log N) + (kld + mean conditional entropy)``,
    where ``sum_lse`` is the second value returned by ``estimate_log_probs``.
    """
    _, z_dim = int_shape(z_mu)
    _, sum_lse = estimate_log_probs(z, z_mu, z_log_sigma_sq, N=N)
    sum_lse = sum_lse - tf.log(float(N)) * float(z_dim)

    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    entropy = tf.reduce_mean(compute_gaussian_entropy(z_mu, z_log_sigma_sq))
    return sum_lse + (kld + entropy)
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    """Encode masked context images with up/left-shifted PixelCNN-style streams.

    ``contexts`` is multiplied by the 3-channel broadcast mask. The 1-channel
    mask and a channel of ones are appended. Two streams ("up" and "up-left")
    are then refined by ``nr_resnet`` gated residual blocks each.

    Args:
        contexts: context image tensor.
        masks: mask tensor passed to ``broadcast_masks_tf``.
        is_training: forwarded to the layers through arg_scope.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel count of every layer and of the output.
        nonlinearity: activation inside the gated residual blocks. ``None``
            falls back to ``tf.nn.elu``.
        bn: whether the layers use batch normalization (a notice is printed
            when True).
        kernel_initializer, kernel_regularizer, counters: forwarded to the
            layers through arg_scope.

    Returns:
        ``nin`` of ``elu`` of the final up-left stream (``nr_filters`` channels).
    """
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    x = contexts * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    if nonlinearity is None:
        # gated_resnet calls nonlinearity(x) unconditionally, so the default
        # None raised a TypeError whenever nr_resnet >= 1.
        nonlinearity = tf.nn.elu
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope(
                [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                xs = int_shape(x)
                # Add a channel of ones to distinguish image from padding later on.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                # Stream for pixels above.
                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]
                # Stream for pixels above and to the left.
                ul_list = [
                    up_shift(up_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    left_shift(up_left_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 1]))
                ]
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                print("    * receptive_field", receptive_field)
                return x_out
示例#22
0
def dense(inputs,
          num_outputs,
          W=None,
          b=None,
          nonlinearity=None,
          bn=False,
          kernel_initializer=None,
          kernel_regularizer=None,
          is_training=False,
          counters={}):
    """Fully connected layer ``inputs @ W + b`` with optional BN and nonlinearity.

    ``W`` and ``b`` are created with ``tf.get_variable`` only when they are
    not supplied. Batch normalization, if enabled, is applied BEFORE the
    nonlinearity.

    Returns:
        Tensor with ``num_outputs`` features.
    """
    with tf.variable_scope(get_name('dense', counters)):
        weight = W if W is not None else tf.get_variable(
            'W',
            shape=[int(inputs.get_shape()[1]), num_outputs],
            dtype=tf.float32,
            trainable=True,
            initializer=kernel_initializer,
            regularizer=kernel_regularizer)
        bias = b if b is not None else tf.get_variable(
            'b',
            shape=[num_outputs],
            dtype=tf.float32,
            trainable=True,
            initializer=tf.constant_initializer(0.),
            regularizer=None)

        h = tf.matmul(inputs, weight) + tf.reshape(bias, [1, num_outputs])
        if bn:
            h = tf.layers.batch_normalization(h, training=is_training)
        if nonlinearity is not None:
            h = nonlinearity(h)
        print("    + dense", int_shape(inputs), int_shape(h),
              nonlinearity, bn)
        return h
示例#23
0
def up_left_shifted_deconv2d(x,
                             num_filters,
                             filter_size=[2, 2],
                             strides=[1, 1],
                             **kwargs):
    """Transposed convolution whose 'VALID' output is cropped at the top and left.

    With 'VALID' padding, ``deconv2d`` produces ``H * stride_h + filter_h - 1``
    rows (and likewise for columns). The first ``filter_h - 1`` rows and the
    first ``filter_w - 1`` columns are dropped, leaving an
    ``H * stride_h`` x ``W * stride_w`` map.

    Args:
        x: 4-D input tensor.
        num_filters: number of output channels.
        filter_size: ``[height, width]`` of the kernel.
        strides: ``[stride_h, stride_w]``.
        **kwargs: forwarded to ``deconv2d``.
    """
    x = deconv2d(x,
                 num_filters,
                 kernel_size=filter_size,
                 strides=strides,
                 padding='VALID',
                 **kwargs)
    # Bug fix: the crop previously started at ``size - filter + 1``, which
    # kept only the LAST ``filter - 1`` rows/columns instead of dropping the
    # first ``filter - 1`` ones.
    return x[:, (filter_size[0] - 1):, (filter_size[1] - 1):, :]
示例#24
0
def down_shifted_deconv2d(x,
                          num_filters,
                          filter_size=[2, 3],
                          strides=[1, 1],
                          **kwargs):
    """Transposed convolution cropped at the bottom and evenly on both sides.

    The 'VALID' output of ``deconv2d`` keeps its first
    ``height - filter_h + 1`` rows. ``int((filter_w - 1) / 2)`` columns are
    removed from both the left and the right edges.
    """
    out = deconv2d(x,
                   num_filters,
                   kernel_size=filter_size,
                   strides=strides,
                   padding='VALID',
                   **kwargs)
    shape = int_shape(out)
    rows_kept = shape[1] - filter_size[0] + 1
    side = int((filter_size[1] - 1) / 2)
    return out[:, :rows_kept, side:(shape[2] - side), :]
示例#25
0
def deconv2d_openai(x,
                    num_filters,
                    filter_size=[3, 3],
                    stride=[1, 1],
                    pad='SAME',
                    nonlinearity=None,
                    kernel_initializer=None,
                    init_scale=1.,
                    counters={},
                    init=False,
                    ema=None,
                    **kwargs):
    """Plain transposed 2-D convolution followed by a bias add.

    Unlike the weight-normalised ``deconv2d``, the kernel ``V`` is used
    directly as the filter.

    Args:
        x: input tensor. ``int_shape(x)`` is indexed at positions 0-2, and the
            static last dimension is used as the input channel count.
        num_filters: number of output channels.
        filter_size: ``[kh, kw]`` kernel size (list, concatenated into the
            variable shape).
        stride: ``[sh, sw]`` strides (list, wrapped as ``[1] + stride + [1]``).
        pad: forwarded as ``padding``. With ``'SAME'`` the output spatial size
            is ``in * stride``; for any other value each dimension grows by
            ``k - 1`` on top of that.
        kernel_initializer: initializer for the kernel variable ``V``.
        counters: dict handed to ``get_name`` to build a unique scope name.
            NOTE(review): the shared ``{}`` default appears to be relied on so
            successive calls get distinct names; keep it unless ``get_name``
            is audited.

    Returns:
        ``bias_add(conv2d_transpose(x, V), b)``.

    NOTE(review): ``nonlinearity``, ``init_scale``, ``init``, ``ema`` and
    ``**kwargs`` are accepted but never used. In particular, no activation is
    applied even when ``nonlinearity`` is provided.
    """
    name = get_name("deconv2d", counters)
    in_shape = int_shape(x)
    out_h = in_shape[1] * stride[0]
    out_w = in_shape[2] * stride[1]
    if pad != 'SAME':
        # Non-SAME padding: a transposed conv overhangs by (k - 1) per axis.
        out_h += filter_size[0] - 1
        out_w += filter_size[1] - 1
    target_shape = [in_shape[0], out_h, out_w, num_filters]
    with tf.variable_scope(name):
        in_channels = int(x.get_shape()[-1])
        kernel = tf.get_variable('V',
                                 shape=filter_size + [num_filters, in_channels],
                                 dtype=tf.float32,
                                 initializer=kernel_initializer,
                                 trainable=True)
        bias = tf.get_variable('b',
                               shape=[num_filters],
                               dtype=tf.float32,
                               initializer=tf.constant_initializer(0.),
                               trainable=True)
        out = tf.nn.conv2d_transpose(x, kernel, target_shape,
                                     [1] + stride + [1], padding=pad)
        return tf.nn.bias_add(out, bias)
示例#26
0
def sample_from_discretized_mix_logistic(l, nr_mix, epsilon=1e-5):
    """Draw one sample from a mixture of logistics over 3 coupled channels.

    Args:
        l: rank-4 tensor of mixture parameters. Its first ``nr_mix`` channels
            are mixture logits. The remaining channels are reshaped to
            ``ls[:-1] + [3, 3 * nr_mix]``, so the last dimension of ``l`` must
            be ``10 * nr_mix``. Along the final axis of that reshape the three
            ``nr_mix``-wide blocks are read as means, log-scales and
            (pre-tanh) channel-coupling coefficients.
        nr_mix: number of mixture components.
        epsilon: clipping bound for the uniform noise used to sample the
            logistic. The Gumbel noise used to pick the component is clipped
            with a hard-coded ``1e-5`` regardless of this value.

    Returns:
        Tensor of shape ``ls[:-1] + [3]`` whose values are clipped to
        ``[-1, 1]``. Samples are not quantised to 8-bit bins.
    """
    ls = int_shape(l)
    xs = ls[:-1] + [3]
    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])
    # sample mixture indicator from softmax via the Gumbel-max trick:
    # argmax(logits + Gumbel noise), where -log(-log(u)) is Gumbel noise.
    sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
    # Add a channel axis so the one-hot selector broadcasts over all 3 channels.
    sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])
    # select logistic parameters
    means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
    # Log-scales are floored at -7 before exponentiation.
    log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
    coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = tf.random_uniform(means.get_shape(), minval=epsilon, maxval=1. - epsilon)
    # Inverse-CDF of the logistic: mean + scale * logit(u).
    x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))
    # Channels are coupled autoregressively: channel 1 is shifted by a multiple
    # of the clipped channel 0, and channel 2 by clipped channels 0 and 1.
    x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
    x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
    x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
    return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)
 def __loss(self, reg):
     """Build the training objective ``self.loss = self.loss_ae + self.loss_reg``.

     Args:
         reg: regulariser added to the reconstruction loss. One of
             ``None`` (no regulariser), ``'kld'``, ``'mmd'`` or ``'tc'``.

     Side effects:
         Resets ``self.mmd``, ``self.kld``, ``self.mi``, ``self.tc`` and
         ``self.dwkld`` to ``None``, then fills in whichever of them the chosen
         regulariser computes. Sets ``self.loss_ae``, ``self.loss_reg`` and
         ``self.loss``.

     Raises:
         ValueError: if ``reg`` is not one of the supported values. Without
             this check ``self.loss_reg`` would be left unset (or stale from an
             earlier call).
     """
     print("******   Compute Loss   ******")
     if reg not in (None, 'kld', 'mmd', 'tc'):
         raise ValueError(
             "Unknown regulariser {0!r}; expected None, 'kld', 'mmd' or "
             "'tc'".format(reg))
     self.mmd = self.kld = self.mi = self.tc = self.dwkld = None
     self.loss_ae = gaussian_recons_loss(self.x, self.x_hat)
     if reg is None:
         self.loss_reg = 0
     elif reg == 'kld':
         self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
         # Hinge at lam: the KL is not penalised below the threshold lam.
         self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
     elif reg == 'mmd':
         # MMD between standard-normal draws and the sampled latents.
         self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                 self.z)
         self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
     else:  # reg == 'tc'
         self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
             self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
         # Only the total-correlation term is weighted by beta.
         self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
     # Log the regulariser actually applied (the argument, not self.reg).
     print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
     self.loss = self.loss_ae + self.loss_reg
def estimate_tc(z, z_mu, z_log_sigma_sq, N=200000):
    """Minibatch-weighted estimate of the total correlation of q(z).

    Each posterior ``N(z_mu[i], exp(z_log_sigma_sq[i]))`` in the batch is
    treated as a mixture component, and every sample ``z[j]`` is scored under
    all of them. Off-diagonal pairs (``i != j``) are up-weighted by
    ``(N - 1) / (batch_size - 1)`` so the batch stands in for a dataset of
    ``N`` points; each sample's own posterior keeps weight 1.

    Args:
        z: latent samples, presumably ``(batch_size, z_dim)`` — TODO confirm;
            only the broadcasting against ``z_mu`` below constrains it.
        z_mu: posterior means; ``int_shape(z_mu)`` must unpack to
            ``(batch_size, z_dim)``.
        z_log_sigma_sq: posterior log-variances, same shape as ``z_mu``.
        N: size of the dataset the minibatch is drawn from.

    Returns:
        Scalar tensor estimating ``E[log q(z)] - sum_d E[log q(z_d)]``. The
        ``-log N`` normalisers of the joint and of each marginal are folded
        into the final ``log(N) * (z_dim - 1)`` term.
    """
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    batch_size, z_dim = int_shape(z_mu)

    # Pairwise standardised residuals via broadcasting (shape B x B x D):
    #   z_norm[i, j, d] = (z[j, d] - z_mu[i, d]) / z_sigma[i, d]
    z_norm = ((tf.expand_dims(z, 0) - tf.expand_dims(z_mu, 1)) /
              tf.expand_dims(z_sigma, 1))

    # log N(z[j, d]; z_mu[i, d], z_sigma[i, d]). The standard-normal log-density
    # of the residual must be corrected by -log(sigma_i) = -0.5 * log(sigma_i^2)
    # (change of variables). The previous code omitted this term, which varies
    # with i and therefore biased the log-sum-exp over components.
    std_normal = tf.distributions.Normal(loc=0., scale=1.)
    log_probs = (std_normal.log_prob(z_norm) -
                 0.5 * tf.expand_dims(z_log_sigma_sq, 1))

    # Log importance weights: 0 on the diagonal, log((N-1)/(B-1)) elsewhere.
    # With batch_size == 1 there are no off-diagonal entries, and the ratio
    # would divide by zero.
    if batch_size > 1:
        off_diag = np.log(float(N - 1) / (batch_size - 1))
    else:
        off_diag = 0.
    ratio = np.full((batch_size, batch_size), off_diag)
    np.fill_diagonal(ratio, 0.)

    # Joint: sum over latent dims first, then log-sum-exp over components i.
    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    # Marginals: log-sum-exp over components per dim, then sum over dims.
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio[:, :, np.newaxis], axis=0),
                      axis=-1))
    return lse_sum - sum_lse + tf.log(float(N)) * (float(z_dim) - 1)
    def __model(self, network_type="large"):
        """Build the model graph: masked-input encoder, latent choice, PixelCNN.

        Creates the input placeholders and encodes the masked image into a
        Gaussian latent (``z_mu``, ``z_log_sigma_sq``). It then chooses which
        latent to decode: a posterior sample, a prior sample, or a fed-in value.
        Finally it conditions a forward PixelCNN on the decoder features
        concatenated with the output of a reverse PixelCNN.

        Args:
            network_type: accepted for interface compatibility but unused;
                the architecture is selected from ``self.network_type``.

        Side effects (attributes set on ``self``):
            Placeholders ``x``, ``x_bar``, ``is_training``, ``dropout_p``,
            ``masks``, ``input_masks``, ``use_prior``, ``z_ph``, ``use_z_ph``;
            tensors ``z_mu``, ``z_log_sigma_sq``, ``z``, ``z_pr``,
            ``mix_logistic_params`` and ``x_hat``.

        Raises:
            ValueError: if ``self.img_size`` is not 32, the only size for
                which networks are selected here.
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            num_channels = 1
        else:
            num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        self.use_prior = tf.placeholder_with_default(False, shape=())
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large
                decoder = conv_decoder_32_large
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        else:
            # Without this, encoder/decoder would be unbound below and fail
            # with an UnboundLocalError.
            raise ValueError(
                "Unsupported img_size {0}; only 32 is supported".format(
                    self.img_size))
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                       **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                inputs = self.x
                # Always true here: input_masks was just assigned a placeholder.
                if self.input_masks is not None:
                    # Keep observed pixels, replace hidden ones with U(-1, 1)
                    # noise, then append the mask itself as an extra channel.
                    # NOTE(review): the mask is broadcast to 3 channels even
                    # when num_channels == 1 ('binary'), so a 1-channel image
                    # silently becomes 3 channels here — confirm intended.
                    inputs = inputs * broadcast_masks_tf(self.input_masks,
                                                         num_channels=3)
                    inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                        1 -
                        broadcast_masks_tf(self.input_masks, num_channels=3))
                    inputs = tf.concat([
                        inputs,
                        broadcast_masks_tf(self.input_masks, num_channels=1)
                    ],
                                       axis=-1)

                self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
                sigma = tf.exp(self.z_log_sigma_sq / 2.)
                self.z = gaussian_sampler(self.z_mu, sigma)
                self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                             tf.ones_like(sigma))

                self.z_ph = tf.placeholder_with_default(
                    tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
                self.use_z_ph = tf.placeholder_with_default(False, shape=())

                # Select the latent with 0/1 float gates:
                # z_ph if use_z_ph, else z_pr if use_prior, else z.
                use_prior = tf.cast(tf.cast(self.use_prior, tf.int32),
                                    tf.float32)
                use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32),
                                   tf.float32)
                z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                    1 - use_z_ph) + use_z_ph * self.z_ph

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                            cond_features,
                                                            bn=False)
                self.x_hat = mix_logistic_sampler(
                    self.mix_logistic_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
def forward_pixel_cnn_32(x, context, nr_logistic_mix=10, nr_resnet=1, nr_filters=100, dropout_p=0.0, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Two-stream gated-resnet autoregressive network (down/down-right shifted).

    Builds a ``u`` stream (down-shifted convolutions, "pixels above") and a
    ``ul`` stream (down-shifted plus down-right-shifted convolutions, "above
    and to the left"). An up pass runs over three resolutions separated by two
    stride-2 convolutions. A mirrored down pass then upsamples with stride-2
    deconvolutions and consumes the stored up-pass activations in reverse
    order as skip connections.

    Args:
        x: input image tensor; assumed rank 4 (channels last), since it is
            concatenated on axis 3.
        context: conditioning tensor passed as ``sh`` to the gated resnets,
            but only in the first (full-resolution) block of the up pass.
        nr_logistic_mix: number of mixture components; the output has
            ``10 * nr_logistic_mix`` channels.
        nr_resnet: gated-resnet blocks per resolution.
        nr_filters: channel width of every layer.
        dropout_p: forwarded to ``gated_resnet``.
        nonlinearity: forwarded to ``gated_resnet``.
        bn: must be False (asserted below).
        kernel_initializer, kernel_regularizer, is_training: forwarded via
            ``arg_scope`` to the resnet/conv/deconv layers.
        counters: dict handed to ``get_name`` and to every layer for unique
            scope names. NOTE(review): the shared ``{}`` default appears to be
            relied on for naming across calls — do not replace it without
            auditing ``get_name``.

    Returns:
        ``nin(elu(ul), 10 * nr_logistic_mix)`` — presumably the parameters of a
        mixture of logistics over 3 colour channels (TODO confirm).

    NOTE: the order in which layers are created determines their counter-based
    names, so statements here should not be reordered.
    """
    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=None, nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d, down_shifted_deconv2d, down_right_shifted_deconv2d], bn=bn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, is_training=is_training, counters=counters):
                # /////// up pass ////////
                xs = int_shape(x)
                x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

                u_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left

                # Full resolution: the only place `context` is injected.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], sh=context, conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], sh=context, conv=down_right_shifted_conv2d))

                # Downsample by 2.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # Downsample by 2 again.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # /////// down pass ////////
                # Each pop() takes the matching up-pass activation as a skip
                # connection; the ul stream also sees the current u.

                u = u_list.pop()
                ul = ul_list.pop()

                for rep in range(nr_resnet):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # Upsample by 2.
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                # nr_resnet + 1 blocks: also consumes the stride-2 conv output.
                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # Upsample by 2 back to the input resolution.
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])


                # nr_resnet + 1 blocks: also consumes the initial stream input.
                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), sh=None, conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), sh=None, conv=down_right_shifted_conv2d)

                # 1x1 projection (nin) to 10 * nr_logistic_mix output channels.
                x_out = nin(tf.nn.elu(ul),10*nr_logistic_mix)
                # Every up-pass activation must have been consumed exactly once.
                assert len(u_list) == 0
                assert len(ul_list) == 0
                return x_out