def _vlb_in_bits_per_dims(self, model, x_start, x_t, t, clip_denoised=True):
    """
    Calculate the variational lower bound in bits/dim.
    """
    B, C, H, W = x_start.shape
    assert x_start.shape == x_t.shape
    assert t.shape == (B, )

    # true parameters
    mean, _, log_var_clipped = self.q_posterior(x_start, x_t, t)

    # pred parameters
    preds = self.p_mean_var(model, x_t, t, clip_denoised=clip_denoised)

    # Negative log-likelihood
    nll = -gaussian_log_likelihood(x_start,
                                   mean=preds.mean, logstd=0.5 * preds.log_var)
    nll_bits = mean_along_except_batch(nll) / np.log(2.0)
    assert nll.shape == x_start.shape
    assert nll_bits.shape == (B, )

    # KL between true and pred in bits
    kl = kl_normal(mean, log_var_clipped, preds.mean, preds.log_var)
    kl_bits = mean_along_except_batch(kl) / np.log(2.0)
    assert kl.shape == x_start.shape
    assert kl_bits.shape == (B, )

    # Return nll at t = 0, otherwise KL(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t))
    return F.where(F.equal_scalar(t, 0), nll_bits, kl_bits)
def gaussian_log_likelihood(x, mean, logstd, orig_max_val=255):
    """
    Compute the log-likelihood of a Gaussian distribution for given data `x`.

    Args:
        x (nn.Variable): Target data. It is assumed that the values are in [-1, 1],
            rescaled from the original range [0, orig_max_val].
        mean (nn.Variable): Gaussian mean. Must be the same shape as x.
        logstd (nn.Variable): Gaussian log standard deviation. Must be the same shape as x.
        orig_max_val (int): The maximum value that x originally has before being rescaled.

    Return:
        Log probabilities of x in nats.
    """
    assert x.shape == mean.shape == logstd.shape
    centered_x = x - mean
    inv_std = F.exp(-logstd)
    half_bin = 1.0 / orig_max_val

    def clamp(val):
        # Here we don't need to clip the max
        return F.clip_by_value(val, min=1e-12, max=1e8)

    # x + 0.5 (in the original scale)
    plus_in = inv_std * (centered_x + half_bin)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    log_cdf_plus = F.log(clamp(cdf_plus))

    # x - 0.5 (in the original scale)
    minus_in = inv_std * (centered_x - half_bin)
    cdf_minus = approx_standard_normal_cdf(minus_in)
    log_one_minus_cdf_minus = F.log(clamp(1.0 - cdf_minus))

    log_cdf_delta = F.log(clamp(cdf_plus - cdf_minus))

    log_probs = F.where(
        F.less_scalar(x, -0.999),
        log_cdf_plus,  # Edge case for 0. It uses the cdf for -inf as cdf_minus.
        F.where(F.greater_scalar(x, 0.999),  # Edge case for orig_max_val. It uses the cdf for +inf as cdf_plus.
                log_one_minus_cdf_minus,
                log_cdf_delta  # otherwise
                )
    )

    assert log_probs.shape == x.shape
    return log_probs
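# --- Usage sketch (not from the original source): evaluating the discretized Gaussian
# log-likelihood above on dummy data. It assumes `approx_standard_normal_cdf`, which
# `gaussian_log_likelihood` calls, is defined in this module, and that the module-level
# imports used throughout this file are in place (`import numpy as np`,
# `import nnabla as nn`, `import nnabla.functions as F`).
def _example_gaussian_log_likelihood():
    shape = (4, 3, 8, 8)
    x = nn.Variable.from_numpy_array(
        np.random.uniform(-1, 1, size=shape).astype(np.float32))
    mean = nn.Variable.from_numpy_array(np.zeros(shape, dtype=np.float32))
    logstd = nn.Variable.from_numpy_array(np.zeros(shape, dtype=np.float32))
    log_probs = gaussian_log_likelihood(x, mean, logstd)  # nats, same shape as x
    log_probs.forward()
    return log_probs.d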
def sample_pdf(bins, weights, N_samples, det=False):
    """Sample additional points for training the fine network.

    Args:
        bins: Array of shape [batch_size, N_bins + 1]. Bin edges of the
            piecewise-constant PDF (e.g. midpoints of the coarse depth samples).
        weights: Array of shape [batch_size, N_bins]. Unnormalized weights
            defining the piecewise-constant PDF.
        N_samples: int. Number of samples to draw from the PDF.
        det: bool. If True, sample deterministically (evenly spaced in CDF space)
            instead of randomly.

    Returns:
        samples: Array of shape [batch_size, N_samples]. Depth samples for the fine network.
    """
    weights += 1e-5
    pdf = weights / F.sum(weights, axis=-1, keepdims=True)
    cdf = F.cumsum(pdf, axis=-1)
    cdf = F.concatenate(F.constant(0, cdf[..., :1].shape), cdf, axis=-1)

    if det:
        u = F.arange(0., 1., 1 / N_samples)
        u = F.broadcast(u[None, :], cdf.shape[:-1] + (N_samples, ))
        u = u.data if isinstance(cdf, nn.NdArray) else u
    else:
        u = F.rand(shape=cdf.shape[:-1] + (N_samples, ))

    indices = F.searchsorted(cdf, u, right=True)

    below = F.maximum_scalar(indices - 1, 0)
    above = F.minimum_scalar(indices, cdf.shape[-1] - 1)
    indices_g = F.stack(below, above, axis=below.ndim)
    cdf_g = F.gather(cdf, indices_g, axis=-1,
                     batch_dims=len(indices_g.shape) - 2)
    bins_g = F.gather(bins, indices_g, axis=-1,
                      batch_dims=len(indices_g.shape) - 2)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = F.where(F.less_scalar(denom, 1e-5),
                    F.constant(1, denom.shape), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
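# --- Usage sketch (not from the original source): hierarchical (inverse-CDF) sampling
# as used for a NeRF fine network. `bins` holds one more entry per ray than `weights`
# so it lines up with the CDF after the leading zero is prepended. Assumes the
# module-level imports used throughout this file (np, nn, F).
def _example_sample_pdf():
    n_rays, n_bins = 4, 15
    bins = nn.Variable.from_numpy_array(
        np.tile(np.linspace(2.0, 6.0, n_bins + 1, dtype=np.float32), (n_rays, 1)))
    weights = nn.Variable.from_numpy_array(
        np.random.rand(n_rays, n_bins).astype(np.float32))
    samples = sample_pdf(bins, weights, N_samples=32, det=True)  # shape: (n_rays, 32)
    samples.forward()
    return samples.d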
def loss_dis_real(logits, rec_imgs, part, img, lmd=1.0):
    # Hinge loss (following the official implementation)
    loss = F.mean(F.relu(0.2*F.rand(shape=logits.shape) + 0.8 - logits))

    # Reconstruction losses for:
    # - rec_img_big (reconstructed from 8x8 features of the original image)
    # - rec_img_small (reconstructed from 8x8 features of the resized image)
    # - rec_img_part (reconstructed from a part of the 16x16 features of the original image)
    if lmd > 0.0:
        # Ground-truth
        img_128 = F.interpolate(img, output_size=(128, 128))
        img_256 = F.interpolate(img, output_size=(256, 256))

        img_half = F.where(F.greater_scalar(
            part[0], 0.5), img_256[:, :, :128, :], img_256[:, :, 128:, :])
        img_part = F.where(F.greater_scalar(
            part[1], 0.5), img_half[:, :, :, :128], img_half[:, :, :, 128:])

        # Integrated perceptual loss
        loss = loss + lmd * \
            reconstruction_loss_lpips(rec_imgs, [img_128, img_part])

    return loss
def sinc_backward(inputs):
    """
    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
        list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    m0 = F.not_equal_scalar(x0, 0)
    m0 = no_grad(m0)
    y0 = get_output(x0, "Sinc")
    # d/dx sinc(x) = (cos(x) - sinc(x)) / x for x != 0, and 0 at x = 0
    dx0 = dy * (F.cos(x0) - y0) / x0
    c0 = F.constant(0, x0.shape)
    dx0 = F.where(m0, dx0, c0)
    return dx0
def where_backward(inputs):
    """
    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
        list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    cd = inputs[1]
    xt = inputs[2]
    xf = inputs[3]
    # Route the incoming gradient to x_true where the condition holds and to
    # x_false elsewhere; the condition itself receives no gradient (None).
    c1 = F.constant(1, xt.shape)
    c0 = F.constant(0, xf.shape)
    m0 = F.where(cd, c1, c0)
    m1 = 1 - m0
    m0 = no_grad(m0)
    m1 = no_grad(m1)
    dx0 = dy * m0
    dx1 = dy * m1
    return None, dx0, dx1
def augment(batch, aug_list, p_aug=1.0):

    if isinstance(p_aug, float):
        p_aug = nn.Variable.from_numpy_array(p_aug * np.ones((1,)))

    if "flip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(2, 3))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "lrflip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(3,))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "translation" in aug_list and batch.shape[2] >= 8:
        rnd = F.rand(shape=[batch.shape[0], ])
        # Currently nnabla does not support random_shift with border_mode="noise"
        mask = np.ones((1, 3, batch.shape[2], batch.shape[3]))
        mask[:, :, :, 0] = 0
        mask[:, :, :, -1] = 0
        mask[:, :, 0, :] = 0
        mask[:, :, -1, :] = 0
        batch_int = F.concatenate(
            batch, nn.Variable().from_numpy_array(mask), axis=0)
        batch_int_aug = F.random_shift(batch_int, shifts=(
            batch.shape[2]//8, batch.shape[3]//8), border_mode="nearest")
        batch_aug = F.slice(batch_int_aug, start=(
            0, 0, 0, 0), stop=batch.shape)
        mask_var = F.slice(batch_int_aug, start=(
            batch.shape[0], 0, 0, 0), stop=batch_int_aug.shape)
        batch_aug = batch_aug * F.broadcast(mask_var, batch_aug.shape)
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "color" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        rnd_contrast = 1.0 + 0.5 * \
            (2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)  # from 0.5 to 1.5
        rnd_brightness = 0.5 * \
            (2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)  # from -0.5 to 0.5
        rnd_saturation = 2.0 * \
            F.rand(shape=[batch.shape[0], 1, 1, 1])  # from 0.0 to 2.0
        # Brightness
        batch_aug = batch + rnd_brightness
        # Saturation
        mean_s = F.mean(batch_aug, axis=1, keepdims=True)
        batch_aug = rnd_saturation * (batch_aug - mean_s) + mean_s
        # Contrast
        mean_c = F.mean(batch_aug, axis=(1, 2, 3), keepdims=True)
        batch_aug = rnd_contrast * (batch_aug - mean_c) + mean_c
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "cutout" in aug_list and batch.shape[2] >= 16:
        batch = F.random_erase(batch, prob=p_aug.d[0],
                               replacements=(0.0, 0.0))

    return batch
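# --- Usage sketch (not from the original source): applying the augmentations above to a
# 3-channel NCHW batch, each sample being augmented with probability 0.3. Assumes the
# module-level imports used throughout this file (np, nn, F).
def _example_augment():
    batch = nn.Variable.from_numpy_array(
        np.random.uniform(-1, 1, size=(8, 3, 64, 64)).astype(np.float32))
    augmented = augment(batch, aug_list=["flip", "translation", "color"], p_aug=0.3)
    augmented.forward()
    return augmented.d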
def projection(x: nn.NdArray, eps: float = 1e-5) -> nn.NdArray:
    # Project each row of x onto the unit ball: rows whose L2 norm is >= 1
    # are rescaled to norm (1 - eps); the remaining rows are left unchanged.
    norm = F.pow_scalar(F.sum(x**2, axis=1), val=0.5)
    return F.where(condition=F.greater_equal_scalar(norm, val=1.),
                   x_true=F.clip_by_norm(x, clip_norm=1 - eps, axis=1),
                   x_false=x)
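# --- Usage sketch (not from the original source): keeping a batch of feature vectors
# inside the unit ball. With nn.NdArray inputs the functions execute imperatively.
# Assumes the module-level imports used throughout this file (np, nn, F).
def _example_projection():
    x = nn.NdArray.from_numpy_array(
        2.0 * np.random.randn(5, 16).astype(np.float32))
    # rows with an L2 norm >= 1 come back rescaled to norm 1 - eps; the rest pass through
    return projection(x)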
def Discriminator(img, label="real", scope_name="Discriminator", ndf=64):
    with nn.parameter_scope(scope_name):
        if type(img) is not list:
            img_small = F.interpolate(img, output_size=(128, 128))
        else:
            img_small = img[1]
            img = img[0]

        def sn_w(w): return PF.spectral_norm(w, dim=0)

        # InitLayer: -> 256x256
        with nn.parameter_scope("init"):
            h = img
            if img.shape[2] == 1024:
                h = PF.convolution(h, ndf // 8, (4, 4), stride=(2, 2),
                                   pad=(1, 1), apply_w=sn_w, with_bias=False, name="conv1")
                h = F.leaky_relu(h, 0.2)
                h = PF.convolution(h, ndf // 4, (4, 4), stride=(2, 2),
                                   pad=(1, 1), apply_w=sn_w, with_bias=False, name="conv2")
                h = PF.batch_normalization(h)
                h = F.leaky_relu(h, 0.2)
            elif img.shape[2] == 512:
                h = PF.convolution(h, ndf // 4, (4, 4), stride=(2, 2),
                                   pad=(1, 1), apply_w=sn_w, with_bias=False, name="conv2")
                h = F.leaky_relu(h, 0.2)
            else:
                h = PF.convolution(h, ndf // 4, (3, 3), pad=(1, 1),
                                   apply_w=sn_w, with_bias=False, name="conv3")
                h = F.leaky_relu(h, 0.2)

        # Calc base features
        f_256 = h
        f_128 = DownsampleComp(f_256, ndf // 2, "down256->128")
        f_64 = DownsampleComp(f_128, ndf * 1, "down128->64")
        f_32 = DownsampleComp(f_64, ndf * 2, "down64->32")

        # Apply SLE
        f_32 = SLE(f_32, f_256, "sle256->32")
        f_16 = DownsampleComp(f_32, ndf * 4, "down32->16")
        f_16 = SLE(f_16, f_128, "sle128->16")
        f_8 = DownsampleComp(f_16, ndf * 16, "down16->8")
        f_8 = SLE(f_8, f_64, "sle64->8")

        # Conv + BN + LeakyReLU + Conv -> logits (5x5)
        with nn.parameter_scope("last"):
            h = PF.convolution(f_8, ndf * 16, (1, 1),
                               apply_w=sn_w, with_bias=False, name="conv1")
            h = PF.batch_normalization(h)
            h = F.leaky_relu(h, 0.2)
            logit_large = PF.convolution(
                h, 1, (4, 4), apply_w=sn_w, with_bias=False, name="conv2")

        # Another path: "down_from_small" in the official code
        with nn.parameter_scope("down_from_small"):
            h_s = PF.convolution(img_small, ndf // 2, (4, 4), stride=(2, 2),
                                 pad=(1, 1), apply_w=sn_w, with_bias=False, name="conv1")
            h_s = F.leaky_relu(h_s, 0.2)
            h_s = Downsample(h_s, ndf * 1, "dfs64->32")
            h_s = Downsample(h_s, ndf * 2, "dfs32->16")
            h_s = Downsample(h_s, ndf * 4, "dfs16->8")
            fea_dec_small = h_s
            logit_small = PF.convolution(
                h_s, 1, (4, 4), apply_w=sn_w, with_bias=False, name="conv2")

        # Concatenate logits
        logits = F.concatenate(logit_large, logit_small, axis=1)

        # Reconstruct images
        rec_img_big = SimpleDecoder(f_8, "dec_big")
        rec_img_small = SimpleDecoder(fea_dec_small, "dec_small")

        part_ax2 = F.rand(shape=(img.shape[0], ))
        part_ax3 = F.rand(shape=(img.shape[0], ))
        f_16_ax2 = F.where(F.greater_scalar(part_ax2, 0.5),
                           f_16[:, :, :8, :], f_16[:, :, 8:, :])
        f_16_part = F.where(F.greater_scalar(part_ax3, 0.5),
                            f_16_ax2[:, :, :, :8], f_16_ax2[:, :, :, 8:])
        rec_img_part = SimpleDecoder(f_16_part, "dec_part")

    if label == "real":
        return logits, [rec_img_big, rec_img_small, rec_img_part], [part_ax2, part_ax3]
    else:
        return logits
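# --- Usage sketch (not from the original source): wiring the discriminator to the
# "real" loss defined earlier in this file. It assumes the building blocks referenced
# above (DownsampleComp, SLE, Downsample, SimpleDecoder) and `reconstruction_loss_lpips`
# are available in this module, and uses a 256x256 input batch for illustration.
def _example_discriminator_real_loss():
    img = nn.Variable(shape=(4, 3, 256, 256))
    logits, rec_imgs, part = Discriminator(img, label="real")
    d_loss_real = loss_dis_real(logits, rec_imgs, part, img, lmd=1.0)
    return d_loss_real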