Example #1
    def _vlb_in_bits_per_dims(self, model, x_start, x_t, t,
                              clip_denoised=True):
        """
        Calculate the variational lower bound (VLB) in bits/dim.
        """
        B, C, H, W = x_start.shape
        assert x_start.shape == x_t.shape
        assert t.shape == (B, )

        # true parameters
        mean, _, log_var_clipped = self.q_posterior(x_start, x_t, t)

        # pred parameters
        preds = self.p_mean_var(model, x_t, t, clip_denoised=clip_denoised)

        # Negative log-likelihood
        nll = -gaussian_log_likelihood(x_start,
                                       mean=preds.mean, logstd=0.5 * preds.log_var)
        nll_bits = mean_along_except_batch(nll) / np.log(2.0)
        assert nll.shape == x_start.shape
        assert nll_bits.shape == (B, )

        # kl between true and pred in bits
        kl = kl_normal(mean, log_var_clipped, preds.mean, preds.log_var)
        kl_bits = mean_along_except_batch(kl) / np.log(2.0)
        assert kl.shape == x_start.shape
        assert kl_bits.shape == (B, )

        # Return nll at t = 0, otherwise KL(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t))
        return F.where(F.equal_scalar(t, 0), nll_bits, kl_bits)
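
The helper mean_along_except_batch is not defined in this snippet; given the shape asserts above, it evidently averages over every non-batch axis. A minimal sketch of such a helper, reusing the F alias from the example:

def mean_along_except_batch(x):
    # Reduce (B, ...) -> (B,) by averaging over all non-batch axes.
    return F.mean(x, axis=tuple(range(1, len(x.shape))))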
Example #2
def gaussian_log_likelihood(x, mean, logstd, orig_max_val=255):
    """
    Compute the log-likelihood of a Gaussian distribution for given data `x`.

    Args:
        x (nn.Variable): Target data. Values are assumed to be rescaled to [-1, 1]
                         from the original range [0, orig_max_val].
        mean (nn.Variable): Gaussian mean. Must be the same shape as x.
        logstd (nn.Variable): Gaussian log standard deviation. Must be the same shape as x.
        orig_max_val (int): The maximum value that x originally has before being rescaled.

    Return:
        Log probabilities of x in nats.
    """
    assert x.shape == mean.shape == logstd.shape
    centered_x = x - mean
    inv_std = F.exp(-logstd)
    half_bin = 1.0 / orig_max_val

    def clamp(val):
        # The inputs are probabilities, so the max bound never binds;
        # it is passed only because F.clip_by_value requires both bounds.
        return F.clip_by_value(val, min=1e-12, max=1e8)

    # x + 0.5 (in original scale)
    plus_in = inv_std * (centered_x + half_bin)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    log_cdf_plus = F.log(clamp(cdf_plus))

    # x - 0.5 (in original scale)
    minus_in = inv_std * (centered_x - half_bin)
    cdf_minus = approx_standard_normal_cdf(minus_in)
    log_one_minus_cdf_minus = F.log(clamp(1.0 - cdf_minus))

    log_cdf_delta = F.log(clamp(cdf_plus - cdf_minus))

    log_probs = F.where(
        F.less_scalar(x, -0.999),
        # Edge case for the original value 0: the lower bin edge is -inf,
        # so the bin probability reduces to cdf_plus.
        log_cdf_plus,
        F.where(F.greater_scalar(x, 0.999),
                # Edge case for orig_max_val: the upper bin edge is +inf,
                # so the bin probability reduces to 1 - cdf_minus.
                log_one_minus_cdf_minus,
                # Otherwise: probability mass of the bin, cdf_plus - cdf_minus.
                log_cdf_delta
                )
    )

    assert log_probs.shape == x.shape
    return log_probs
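
approx_standard_normal_cdf is not shown either. DDPM-style codebases commonly use a tanh-based approximation of the standard normal CDF; a sketch under that assumption, reusing the np and F aliases from above:

def approx_standard_normal_cdf(x):
    # Tanh approximation of the standard normal CDF (assumed form,
    # as commonly used in DDPM implementations).
    return 0.5 * (1.0 + F.tanh(np.sqrt(2.0 / np.pi)
                               * (x + 0.044715 * F.pow_scalar(x, 3))))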
Example #3
def sample_pdf(bins, weights, N_samples, det=False):
    """Sample additional points for training the fine network.

    Args:
      bins: array of shape [batch, M + 1]. Positions (e.g. depths) bounding each bin.
      weights: array of shape [batch, M]. Unnormalized probability mass of each bin.
      N_samples: int. Number of samples to draw per batch element.
      det: bool. If True, sample deterministically (evenly spaced in CDF space).

    Returns:
      samples: array of shape [batch, N_samples]. Depth samples for the fine network.
    """
    weights += 1e-5  # avoid division by zero when all weights vanish
    pdf = weights / F.sum(weights, axis=-1, keepdims=True)

    cdf = F.cumsum(pdf, axis=-1)
    cdf = F.concatenate(F.constant(0, cdf[..., :1].shape), cdf, axis=-1)

    if det:
        u = F.arange(0., 1., 1 / N_samples)
        u = F.broadcast(u[None, :], cdf.shape[:-1] + (N_samples, ))
        u = u.data if isinstance(cdf, nn.NdArray) else u
    else:
        u = F.rand(shape=cdf.shape[:-1] + (N_samples, ))

    indices = F.searchsorted(cdf, u, right=True)
    below = F.maximum_scalar(indices - 1, 0)
    above = F.minimum_scalar(indices, cdf.shape[-1] - 1)
    indices_g = F.stack(below, above, axis=below.ndim)
    cdf_g = F.gather(cdf,
                     indices_g,
                     axis=-1,
                     batch_dims=len(indices_g.shape) - 2)
    bins_g = F.gather(bins,
                      indices_g,
                      axis=-1,
                      batch_dims=len(indices_g.shape) - 2)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = F.where(F.less_scalar(denom, 1e-5), F.constant(1, denom.shape),
                    denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
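
In NeRF-style hierarchical sampling this is typically fed the midpoints of the coarse depth bins together with the coarse weights. A hypothetical call (z_vals and weights are illustrative names):

# z_vals: [batch, N_coarse] coarse depths; weights: [batch, N_coarse] coarse weights
z_mid = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1])               # [batch, N_coarse - 1]
z_fine = sample_pdf(z_mid, weights[:, 1:-1], N_samples=128)  # [batch, 128]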
Example #4
def loss_dis_real(logits, rec_imgs, part, img, lmd=1.0):

    # Hinge loss (following the official implementation)
    loss = F.mean(F.relu(0.2*F.rand(shape=logits.shape) + 0.8 - logits))

    # Reconstruction loss for rec_img_big (reconstructed from 8x8 features of the original image)
    # Reconstruction loss for rec_img_small (reconstructed from 8x8 features of the resized image)
    # Reconstruction loss for rec_img_part (reconstructed from a part of 16x16 features of the original image)
    if lmd > 0.0:
        # Ground-truth
        img_128 = F.interpolate(img, output_size=(128, 128))
        img_256 = F.interpolate(img, output_size=(256, 256))

        img_half = F.where(F.greater_scalar(
            part[0], 0.5), img_256[:, :, :128, :], img_256[:, :, 128:, :])
        img_part = F.where(F.greater_scalar(
            part[1], 0.5), img_half[:, :, :, :128], img_half[:, :, :, 128:])

        # Integrated perceptual loss
        loss = loss + lmd * \
            reconstruction_loss_lpips(rec_imgs, [img_128, img_part])

    return loss
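
A hypothetical wiring with the Discriminator from Example #9, whose "real" branch returns exactly the logits, reconstructions, and part indicators this loss expects:

logits, rec_imgs, part = Discriminator(real_img, label="real")
d_loss_real = loss_dis_real(logits, rec_imgs, part, real_img, lmd=1.0)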
Example #5
def sinc_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): The incoming gradient, followed by the
                                    inputs of the forward function.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]  # upstream gradient
    x0 = inputs[1]  # input of the forward function
    m0 = F.not_equal_scalar(x0, 0)  # mask: 1 where x != 0
    m0 = no_grad(m0)
    y0 = get_output(x0, "Sinc")  # reuse the forward output y = sinc(x)

    # d/dx sinc(x) = (cos(x) - sinc(x)) / x; the x = 0 case is masked below.
    dx0 = dy * (F.cos(x0) - y0) / x0
    c0 = F.constant(0, x0.shape)
    dx0 = F.where(m0, dx0, c0)  # the derivative at x = 0 is 0
    return dx0
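
Since sinc(x) = sin(x)/x, its derivative is (cos(x) - sinc(x)) / x, which is what dx0 computes. A quick numerical sanity check of that identity in plain numpy:

import numpy as np

def sinc(v):
    return np.sin(v) / v

x, eps = 0.7, 1e-6
numeric = (sinc(x + eps) - sinc(x - eps)) / (2 * eps)
analytic = (np.cos(x) - sinc(x)) / x
assert np.isclose(numeric, analytic)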
Example #6
def where_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): The incoming gradient, followed by the
                                    inputs of the forward function.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]  # upstream gradient
    cd = inputs[1]  # condition
    xt = inputs[2]  # x_true
    xf = inputs[3]  # x_false
    c1 = F.constant(1, xt.shape)
    c0 = F.constant(0, xf.shape)
    m0 = F.where(cd, c1, c0)  # 1 where the true branch was selected
    m1 = 1 - m0               # 1 where the false branch was selected
    m0 = no_grad(m0)
    m1 = no_grad(m1)
    dx0 = dy * m0
    dx1 = dy * m1
    # The condition itself receives no gradient.
    return None, dx0, dx1
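
The masks route the upstream gradient to whichever branch the condition selected. The same routing, spelled out in plain numpy:

import numpy as np

cond = np.array([True, False, True])
dy = np.array([1.0, 2.0, 3.0])
m0 = cond.astype(float)
dx_true = dy * m0           # [1., 0., 3.]
dx_false = dy * (1.0 - m0)  # [0., 2., 0.]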
Example #7
def augment(batch, aug_list, p_aug=1.0):

    if isinstance(p_aug, float):
        p_aug = nn.Variable.from_numpy_array(p_aug * np.ones((1,)))

    if "flip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(2, 3))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "lrflip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(3,))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "translation" in aug_list and batch.shape[2] >= 8:
        rnd = F.rand(shape=[batch.shape[0], ])
        # Currently nnabla does not support random_shift with border_mode="noise",
        # so shift a zero-border mask together with the batch and use it to
        # zero out the pixels shifted in from the border.
        mask = np.ones((1, 3, batch.shape[2], batch.shape[3]))
        mask[:, :, :, 0] = 0
        mask[:, :, :, -1] = 0
        mask[:, :, 0, :] = 0
        mask[:, :, -1, :] = 0
        batch_int = F.concatenate(
            batch, nn.Variable.from_numpy_array(mask), axis=0)
        batch_int_aug = F.random_shift(batch_int, shifts=(
            batch.shape[2]//8, batch.shape[3]//8), border_mode="nearest")
        batch_aug = F.slice(batch_int_aug, start=(
            0, 0, 0, 0), stop=batch.shape)
        mask_var = F.slice(batch_int_aug, start=(
            batch.shape[0], 0, 0, 0), stop=batch_int_aug.shape)
        batch_aug = batch_aug * F.broadcast(mask_var, batch_aug.shape)
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "color" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        rnd_contrast = 1.0 + 0.5 * (
            2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)  # in [0.5, 1.5]
        rnd_brightness = 0.5 * (
            2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)  # in [-0.5, 0.5]
        rnd_saturation = 2.0 * F.rand(
            shape=[batch.shape[0], 1, 1, 1])  # in [0.0, 2.0]
        # Brightness
        batch_aug = batch + rnd_brightness
        # Saturation
        mean_s = F.mean(batch_aug, axis=1, keepdims=True)
        batch_aug = rnd_saturation * (batch_aug - mean_s) + mean_s
        # Contrast
        mean_c = F.mean(batch_aug, axis=(1, 2, 3), keepdims=True)
        batch_aug = rnd_contrast * (batch_aug - mean_c) + mean_c
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "cutout" in aug_list and batch.shape[2] >= 16:
        batch = F.random_erase(batch, prob=p_aug.d[0], replacements=(0.0, 0.0))

    return batch
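
A hypothetical call that flips and color-jitters roughly half of each batch on average:

batch = augment(batch, aug_list=["lrflip", "color"], p_aug=0.5)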
Example #8
def projection(x: nn.NdArray, eps: float = 1e-5) -> nn.NdArray:
    # Rows whose L2 norm reaches 1 are rescaled to norm 1 - eps;
    # shorter rows pass through unchanged.
    norm = F.pow_scalar(F.sum(x**2, axis=1), val=0.5)
    return F.where(condition=F.greater_equal_scalar(norm, val=1.),
                   x_true=F.clip_by_norm(x, clip_norm=1 - eps, axis=1),
                   x_false=x)
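
A hypothetical check of the behavior, assuming the nn/np aliases used elsewhere in these examples:

x = nn.NdArray.from_numpy_array(np.array([[3.0, 4.0], [0.1, 0.2]]))
y = projection(x)  # row 0 (norm 5) is rescaled to norm 1 - eps; row 1 passes through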
Example #9
def Discriminator(img, label="real", scope_name="Discriminator", ndf=64):
    with nn.parameter_scope(scope_name):
        if not isinstance(img, list):
            img_small = F.interpolate(img, output_size=(128, 128))
        else:
            img_small = img[1]
            img = img[0]

        def sn_w(w):
            return PF.spectral_norm(w, dim=0)

        # InitLayer: -> 256x256
        with nn.parameter_scope("init"):
            h = img
            if img.shape[2] == 1024:
                h = PF.convolution(h,
                                   ndf // 8, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv1")
                h = F.leaky_relu(h, 0.2)
                h = PF.convolution(h,
                                   ndf // 4, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv2")
                h = PF.batch_normalization(h)
                h = F.leaky_relu(h, 0.2)
            elif img.shape[2] == 512:
                h = PF.convolution(h,
                                   ndf // 4, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv2")
                h = F.leaky_relu(h, 0.2)
            else:
                h = PF.convolution(h,
                                   ndf // 4, (3, 3),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv3")
                h = F.leaky_relu(h, 0.2)

        # Calc base features
        f_256 = h
        f_128 = DownsampleComp(f_256, ndf // 2, "down256->128")
        f_64 = DownsampleComp(f_128, ndf * 1, "down128->64")
        f_32 = DownsampleComp(f_64, ndf * 2, "down64->32")

        # Apply SLE
        f_32 = SLE(f_32, f_256, "sle256->32")
        f_16 = DownsampleComp(f_32, ndf * 4, "down32->16")
        f_16 = SLE(f_16, f_128, "sle128->16")
        f_8 = DownsampleComp(f_16, ndf * 16, "down16->8")
        f_8 = SLE(f_8, f_64, "sle64->8")

        # Conv + BN + LeakyReLU + Conv -> logits (5x5)
        with nn.parameter_scope("last"):
            h = PF.convolution(f_8,
                               ndf * 16, (1, 1),
                               apply_w=sn_w,
                               with_bias=False,
                               name="conv1")
            h = PF.batch_normalization(h)
            h = F.leaky_relu(h, 0.2)
            logit_large = PF.convolution(h,
                                         1, (4, 4),
                                         apply_w=sn_w,
                                         with_bias=False,
                                         name="conv2")

        # Another path: "down_from_small" in the official code
        with nn.parameter_scope("down_from_small"):
            h_s = PF.convolution(img_small,
                                 ndf // 2, (4, 4),
                                 stride=(2, 2),
                                 pad=(1, 1),
                                 apply_w=sn_w,
                                 with_bias=False,
                                 name="conv1")
            h_s = F.leaky_relu(h_s, 0.2)
            h_s = Downsample(h_s, ndf * 1, "dfs64->32")
            h_s = Downsample(h_s, ndf * 2, "dfs32->16")
            h_s = Downsample(h_s, ndf * 4, "dfs16->8")
            fea_dec_small = h_s
            logit_small = PF.convolution(h_s,
                                         1, (4, 4),
                                         apply_w=sn_w,
                                         with_bias=False,
                                         name="conv2")

        # Concatenate logits
        logits = F.concatenate(logit_large, logit_small, axis=1)

        # Reconstruct images
        rec_img_big = SimpleDecoder(f_8, "dec_big")
        rec_img_small = SimpleDecoder(fea_dec_small, "dec_small")
        part_ax2 = F.rand(shape=(img.shape[0], ))
        part_ax3 = F.rand(shape=(img.shape[0], ))
        f_16_ax2 = F.where(F.greater_scalar(part_ax2, 0.5), f_16[:, :, :8, :],
                           f_16[:, :, 8:, :])
        f_16_part = F.where(F.greater_scalar(part_ax3, 0.5),
                            f_16_ax2[:, :, :, :8], f_16_ax2[:, :, :, 8:])
        rec_img_part = SimpleDecoder(f_16_part, "dec_part")

    if label == "real":
        return logits, [rec_img_big, rec_img_small,
                        rec_img_part], [part_ax2, part_ax3]
    else:
        return logits