def style_mixing(self, test_config, args):
    from nnabla.utils.image_utils import imsave, imresize
    print('Testing style mixing of generation...')

    z1 = F.randn(shape=(args.batch_size_A, test_config['latent_dim']),
                 seed=args.seed_1[0]).data
    z2 = F.randn(shape=(args.batch_size_B, test_config['latent_dim']),
                 seed=args.seed_2[0]).data

    nn.set_auto_forward(True)

    # Build the grid of style-mixed images: each cell takes the coarse
    # styles from z1[i] and the fine styles from z2[j].
    mix_image_stacks = []
    for i in range(args.batch_size_A):
        image_column = []
        for j in range(args.batch_size_B):
            style_noises = [F.reshape(z1[i], (1, 512)),
                            F.reshape(z2[j], (1, 512))]
            rgb_output = self.generator(
                1, style_noises, test_config['truncation_psi'],
                mixing_layer_index=test_config['mix_after'])
            image = save_generations(rgb_output, None, return_images=True)
            image_column.append(image[0])
        image_column = np.concatenate(image_column, axis=1)
        mix_image_stacks.append(image_column)
    mix_image_stacks = np.concatenate(mix_image_stacks, axis=2)

    # Images generated purely from z1 (concatenated horizontally: top row).
    style_noises = [z1, z1]
    rgb_output = self.generator(args.batch_size_A, style_noises,
                                test_config['truncation_psi'])
    image_A = save_generations(rgb_output, None, return_images=True)
    image_A = np.concatenate(image_A, axis=2)

    # Images generated purely from z2 (concatenated vertically: left column).
    style_noises = [z2, z2]
    rgb_output = self.generator(args.batch_size_B, style_noises,
                                test_config['truncation_psi'])
    image_B = save_generations(rgb_output, None, return_images=True)
    image_B = np.concatenate(image_B, axis=1)

    # Assemble the final grid with a blank top-left corner.
    top_image = 255 * np.ones(rgb_output[0].shape).astype(np.uint8)
    top_image = np.concatenate((top_image, image_A), axis=2)
    grid_image = np.concatenate((image_B, mix_image_stacks), axis=2)
    grid_image = np.concatenate((top_image, grid_image), axis=1)

    filename = os.path.join(self.results_dir, 'style_mix.png')
    imsave(filename,
           imresize(grid_image, (1024, 1024), channel_first=True),
           channel_first=True)
    print(f'Output saved as {filename}')
def ce_loss_with_uncertainty(ctx, pred, y_l, log_var):
    # Perturb the logits with noise whose variance is the learned
    # uncertainty: h = pred + exp(log_var)^0.5 * eps, eps ~ N(0, 1).
    r = F.randn(0., 1., log_var.shape)
    r = F.pow_scalar(F.exp(log_var), 0.5) * r
    h = pred + r
    with nn.context_scope(ctx):
        loss_ce = F.mean(F.softmax_cross_entropy(h, y_l))
    return loss_ce
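# Hedged usage sketch for ce_loss_with_uncertainty above: a minimal sketch
# assuming a CPU context and dummy logits/labels; the shapes and values are
# illustrative assumptions, not part of the source.
import numpy as np
import nnabla as nn
from nnabla.ext_utils import get_extension_context

def _demo_ce_loss_with_uncertainty():
    ctx = get_extension_context('cpu')
    pred = nn.Variable.from_numpy_array(
        np.random.randn(8, 10).astype(np.float32))  # logits (B, classes)
    y_l = nn.Variable.from_numpy_array(
        np.random.randint(0, 10, size=(8, 1)))      # integer class labels
    log_var = nn.Variable.from_numpy_array(
        np.zeros((8, 10), dtype=np.float32))        # log-variance of logits
    loss = ce_loss_with_uncertainty(ctx, pred, y_l, log_var)
    loss.forward()
    print(loss.d)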
def test_randn_forward_backward(seed, ctx, func_name, mu, sigma, shape):
    with nn.context_scope(ctx):
        o = F.randn(mu, sigma, shape, seed=seed)
    assert o.shape == tuple(shape)
    assert o.parent.name == func_name
    o.forward()
    if o.size >= 10000:
        est_mu = o.d.mean()
        est_sigma = o.d.std()
    else:
        # Too few samples in a single draw; accumulate repeated forwards.
        data = []
        for i in range(10000):
            o.forward()
            data += [o.d.copy()]
        est_mu = np.mean(np.array(data))
        est_sigma = np.std(np.array(data))
    assert np.isclose(est_mu, mu, atol=sigma)
    assert np.isclose(est_sigma, sigma, atol=sigma)

    # Checking recomputation
    func_args = [mu, sigma, shape, seed]
    recomputation_test(rng=None, func=F.randn, vinputs=[],
                       func_args=func_args, func_kwargs={}, ctx=ctx)
def sample_noise(inpt_size, out_size):
    # Factorized Gaussian noise: f(x) = sign(x) * sqrt(|x|).
    def _f(x):
        return F.sign(x) * F.pow_scalar(F.abs(x), 0.5)

    noise = _f(F.randn(shape=(inpt_size + out_size, )))
    # Outer product of the input-side and output-side noise factors.
    eps_w = F.batch_matmul(F.reshape(noise[:inpt_size], (1, -1)),
                           F.reshape(noise[inpt_size:], (1, -1)),
                           True)
    eps_b = noise[inpt_size:]
    return eps_w, eps_b
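# A minimal sketch of how the factorized noise from sample_noise above could
# be consumed by a NoisyNet-style layer; the layer sizes and the way eps_w /
# eps_b combine with learnable parameters are assumptions.
import nnabla as nn

def _demo_sample_noise():
    nn.set_auto_forward(True)
    eps_w, eps_b = sample_noise(inpt_size=4, out_size=2)
    # eps_w is the (4, 2) outer product of the input/output noise factors;
    # eps_b is the (2,) output factor. A noisy affine layer would compute
    # y = x @ (W + sigma_w * eps_w) + (b + sigma_b * eps_b).
    print(eps_w.shape, eps_b.shape)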
def __init__(self, bs, **kwargs):
    # Random affine perturbation around the identity transform.
    noise = F.randn(mu=0, sigma=kwargs['sigma_affine'], shape=(bs, 2, 3))
    self.theta = noise + nn.Variable.from_numpy_array(
        np.array([[[1., 0., 0.], [0., 1., 0.]]]))
    self.bs = bs

    if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
        # Thin-plate-spline control points and their random parameters.
        self.tps = True
        self.control_points = make_coordinate_grid(
            (kwargs['points_tps'], kwargs['points_tps']))
        self.control_points = F.reshape(
            self.control_points, (1,) + self.control_points.shape)
        self.control_params = F.randn(
            mu=0, sigma=kwargs['sigma_tps'],
            shape=(bs, 1, kwargs['points_tps'] ** 2))
    else:
        self.tps = False
def graph(x1):
    x1 = F.identity(x1).apply(recompute=True)
    x2 = F.randn(shape=x1.shape, seed=123).apply(recompute=True)
    x3 = F.rand(shape=x1.shape, seed=456).apply(recompute=True)
    y = F.mul2(x1, x2).apply(recompute=True)
    y = F.mul2(y, x3).apply(recompute=True)
    y = F.identity(y)
    return y
def build_static_graph(self):
    real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                  self.img_size))
    noises = [F.randn(shape=(self.batch_size, self.config['latent_dim']))
              for _ in range(2)]

    if self.config['regularize_gen']:
        fake_img, dlatents = self.generator(self.batch_size, noises,
                                            return_latent=True)
    else:
        fake_img = self.generator(self.batch_size, noises)
    fake_img_test = self.generator_ema(self.batch_size, noises)

    gen_loss = gen_nonsaturating_loss(self.discriminator(fake_img))
    fake_disc_out = self.discriminator(fake_img)
    real_disc_out = self.discriminator(real_img)
    disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

    var_name_list = ['real_img', 'noises', 'fake_img', 'gen_loss',
                     'disc_loss', 'fake_disc_out', 'real_disc_out',
                     'fake_img_test']
    var_list = [real_img, noises, fake_img, gen_loss, disc_loss,
                fake_disc_out, real_disc_out, fake_img_test]

    if self.config['regularize_gen']:
        # Path-length regularization; the `0 *` term ties the F.assign
        # update into the loss graph so it runs on every forward.
        dlatents.need_grad = True
        mean_path_length = nn.Variable()
        pl_reg, path_mean, _ = gen_path_regularize(
            fake_img=fake_img,
            latents=dlatents,
            mean_path_length=mean_path_length)
        path_mean_update = F.assign(mean_path_length, path_mean)
        path_mean_update.name = 'path_mean_update'
        pl_reg += 0 * path_mean_update
        gen_loss_reg = gen_loss + pl_reg
        var_name_list.append('gen_loss_reg')
        var_list.append(gen_loss_reg)

    if self.config['regularize_disc']:
        # R1 gradient penalty on real images.
        real_img.need_grad = True
        real_disc_out = self.discriminator(real_img)
        disc_loss_reg = disc_loss + self.config['r1_coeff'] * 0.5 * \
            disc_r1_loss(real_disc_out, real_img) * \
            self.config['disc_reg_step']
        real_img.need_grad = False
        var_name_list.append('disc_loss_reg')
        var_list.append(disc_loss_reg)

    Parameters = namedtuple('Parameters', var_name_list)
    self.parameters = Parameters(*var_list)
def generate_z_anchor(self):
    # Anchor latents: gather stored z vectors and jitter them with
    # Gaussian noise of std `anch_std`.
    z_anchor_list = []
    for _ in range(2):
        z_anchor_var = F.gather(
            self.init_z_var,
            combination(self.n_train, self.batch_size)) + F.randn(
                sigma=self.anch_std,
                shape=(self.batch_size, self.latent_dim))
        z_anchor_list.append(z_anchor_var)
    return z_anchor_list
def sample(self, mu, logvar):
    r"""Samples from a Gaussian distribution.

    Args:
        mu (nn.Variable): Mean of the distribution of shape (B, D, 1).
        logvar (nn.Variable): Log variance of the distribution of shape
            (B, D, 1).

    Returns:
        nn.Variable: A sample.
    """
    if self.training:
        # Reparameterization trick: z = mu + std * eps.
        eps = F.randn(shape=mu.shape)
        return mu + F.exp(0.5 * logvar) * eps
    return mu
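# A standalone sketch of the reparameterization trick that `sample` above
# implements: z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I). The
# shapes are illustrative assumptions.
import numpy as np
import nnabla as nn
import nnabla.functions as F

def _demo_reparameterization():
    nn.set_auto_forward(True)
    mu = nn.Variable.from_numpy_array(
        np.zeros((4, 16, 1), dtype=np.float32))
    logvar = nn.Variable.from_numpy_array(
        np.zeros((4, 16, 1), dtype=np.float32))
    eps = F.randn(shape=mu.shape)
    z = mu + F.exp(0.5 * logvar) * eps  # differentiable w.r.t. mu, logvar
    print(z.shape)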
def gen_path_regularize(fake_img, latents, mean_path_length, decay=0.01,
                        pl_weight=2.0):
    # Project the image onto random noise to get a scalar, then measure
    # the length of its Jacobian w.r.t. the latents.
    noise = F.randn(shape=fake_img.shape) / \
        np.sqrt(fake_img.shape[2] * fake_img.shape[3])
    gradient = nn.grad([F.sum(fake_img * noise)], [latents])[0]
    path_lengths = F.mean(F.sum(F.pow_scalar(gradient, 2), axis=1), axis=0)
    path_lengths = F.pow_scalar(path_lengths, 0.5)
    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * \
        (F.mean(path_lengths) - mean_path_length)
    # Penalize the squared deviation from the running mean.
    path_penalty = F.mean(
        F.pow_scalar(path_lengths - F.reshape(path_mean, (1, ),
                                              inplace=False), 2))
    return path_penalty * pl_weight, path_mean, path_lengths
def q_sample(self, x_start, t, noise=None):
    """
    Diffuse the data (t == 0 means diffused for 1 step), sampling from
    q(x_t | x_0):

        x_t = sqrt(cumprod(alpha_0, ..., alpha_t)) * x_0
              + sqrt(1 - cumprod(alpha_0, ..., alpha_t)) * epsilon

    Args:
        x_start (nn.Variable): The (B, C, ...) tensor of x_0.
        t (nn.Variable): A 1-D tensor of timesteps.
        noise (nn.Variable, optional): The noise epsilon to use. If None,
            standard Gaussian noise is drawn.

    Return:
        x_t (nn.Variable): The (B, C, ...) tensor of x_t. Each sample
            x_t[i] corresponds to the noisy image at timestep t[i]
            constructed from x_start[i].
    """
    if noise is None:
        noise = F.randn(shape=x_start.shape)
    assert noise.shape == x_start.shape
    return (
        self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        + self._extract(self.sqrt_one_minus_alphas_cumprod, t,
                        x_start.shape) * noise
    )
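# A numpy-only sketch of the closed-form forward process q(x_t | x_0) that
# q_sample above evaluates; the linear beta schedule and the shapes are
# illustrative assumptions.
import numpy as np

def _demo_q_sample():
    T = 1000
    betas = np.linspace(1e-4, 2e-2, T)
    alphas_cumprod = np.cumprod(1.0 - betas)
    x0 = np.random.randn(1, 3, 32, 32)
    eps = np.random.randn(*x0.shape)
    t = 500
    # x_t = sqrt(prod alpha) * x_0 + sqrt(1 - prod alpha) * eps
    x_t = (np.sqrt(alphas_cumprod[t]) * x0
           + np.sqrt(1.0 - alphas_cumprod[t]) * eps)
    return x_t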
def define_network(self):
    if self.use_inst:
        obj_onehot, bm = encode_inputs(self.ist_mask, self.obj_mask,
                                       n_ids=self.conf.n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = self.obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    generator = SpadeGenerator(self.conf.g_ndf,
                               image_shape=self.conf.image_shape)
    z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
    fake = generator(z, mask)

    # Pixel intensities of fake are in [-1, 1]. Rescale to [0, 1].
    fake = (fake + 1) / 2

    return fake
def _smoothing_target(policy_tp1, sigma, c):
    # TD3-style target policy smoothing: add clipped Gaussian noise to the
    # target action, then clip the result to the valid action range.
    noise_shape = policy_tp1.shape
    smoothing_noise = F.randn(sigma=sigma, shape=noise_shape)
    clipped_noise = clip_by_value(smoothing_noise, -c, c)
    return clip_by_value(policy_tp1 + clipped_noise, -1.0, 1.0)
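# Hedged usage sketch for _smoothing_target above; the action shape and the
# sigma/c values are assumptions, and clip_by_value is assumed to be
# imported as in the surrounding module.
import nnabla as nn
import nnabla.functions as F

def _demo_smoothing_target():
    nn.set_auto_forward(True)
    policy_tp1 = F.randn(shape=(32, 6))  # stand-in for pi(s_{t+1})
    smoothed = _smoothing_target(policy_tp1, sigma=0.2, c=0.5)
    # Noise is clipped to [-c, c] and the perturbed action to [-1, 1].
    print(smoothed.shape)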
def volumetric_rendering(radiance_field, ray_origins, depth_values,
                         return_weights=False, white_bkgd=False,
                         raw_noise_std=0.0, apply_act=False):
    """Integration of volumetric rendering.

    Args:
        radiance_field (nn.Variable or nn.NdArray): Shape is
            (height, width, num_samples, 4). radiance_field[..., :3]
            corresponds to the rgb value at each sampled point, while
            radiance_field[..., -1] refers to the color density.
        ray_origins (nn.Variable or nn.NdArray): Shape is
            (height, width, 3).
        depth_values (nn.Variable or nn.NdArray): Shape is
            (num_samples, 1) or (height, width, num_samples).
        return_weights (bool, optional): Set to True if the coefficients
            of the volumetric integration sum are to be returned.
            Defaults to False.

    Returns:
        rgb_map (nn.Variable or nn.NdArray): Shape is (height, width, 3).
        depth_map (nn.Variable or nn.NdArray): Shape is (height, width, 1).
        acc_map (nn.Variable or nn.NdArray): Accumulated weights along
            each ray.
    """
    if apply_act:
        sigma = F.relu(radiance_field[..., 3])
        rgb = F.sigmoid(radiance_field[..., :3])
    else:
        sigma = radiance_field[..., 3]
        rgb = radiance_field[..., :3]

    if raw_noise_std > 0.0:
        # Regularize the density by perturbing the raw predictions.
        noise = F.randn(shape=sigma.shape)
        sigma += (noise * raw_noise_std)

    if depth_values.ndim == 2:
        distances = depth_values[:, 1:] - depth_values[:, :-1]
        distances = F.concatenate(
            distances,
            F.constant(1e2, shape=depth_values.shape[:-1] + (1, )),
            axis=-1)
        alpha = 1. - F.exp(-sigma * distances)
        weights = alpha * F.cumprod(1 - alpha + 1e-10, axis=-1,
                                    exclusive=True)
        rgb_map = F.sum(weights[..., None] * rgb, axis=-2)
        depth_map = F.sum(weights * depth_values, axis=-1)
        acc_map = F.sum(weights, axis=-1)
    else:
        distances = depth_values[:, :, 1:] - depth_values[:, :, :-1]
        distances = F.concatenate(
            distances,
            F.constant(1e10, shape=depth_values.shape[:-1] + (1, )),
            axis=-1)
        alpha = 1. - F.exp(-sigma * distances)
        weights = alpha * F.cumprod(1 - alpha + 1e-10, axis=-1,
                                    exclusive=True)
        rgb_map = F.sum(weights[..., None] * rgb, axis=rgb.ndim - 2)
        depth_map = F.sum(weights * depth_values, axis=-1)
        acc_map = F.sum(weights, axis=-1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    if return_weights:
        disp_map = 1.0 / F.maximum2(F.constant(1e-10, depth_map.shape),
                                    depth_map / acc_map)
        return rgb_map, depth_map, acc_map, disp_map, weights

    return rgb_map, depth_map, acc_map
def generate_z_normal(self):
    z_normal_list = [
        F.randn(shape=(self.batch_size, self.latent_dim))
        for _ in range(2)
    ]
    return z_normal_list
def generate_data(args):
    if not os.path.isfile(os.path.join(args.weights_path, 'gen_params.h5')):
        os.makedirs(args.weights_path, exist_ok=True)
        print("Downloading the pretrained tf-converted weights. "
              "Please wait...")
        url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
        from nnabla.utils.data_source_loader import download
        download(url, os.path.join(args.weights_path, 'gen_params.h5'),
                 False)

    nn.load_parameters(os.path.join(args.weights_path, 'gen_params.h5'))
    print('Loaded pretrained weights from tensorflow!')

    os.makedirs(args.save_image_path, exist_ok=True)

    batches = [args.batch_size
               for _ in range(args.num_images // args.batch_size)]
    if args.num_images % args.batch_size != 0:
        batches.append(
            args.num_images
            - (args.num_images // args.batch_size) * args.batch_size)

    for idx, batch_size in enumerate(batches):
        z = [F.randn(shape=(batch_size, 512)).data,
             F.randn(shape=(batch_size, 512)).data]

        # Normalize each latent to (approximately) unit RMS.
        for i in range(len(z)):
            z[i] = F.div2(
                z[i],
                F.pow_scalar(F.add_scalar(
                    F.mean(z[i] ** 2., axis=1, keepdims=True), 1e-8),
                    0.5, inplace=True))

        # get latent code
        w = [mapping_network(z[0], outmaps=512, num_layers=8)]
        w += [mapping_network(z[1], outmaps=512, num_layers=8)]

        # truncation trick
        dlatent_avg = nn.parameter.get_parameter_or_create(
            name="dlatent_avg", shape=(1, 512))
        w = [lerp(dlatent_avg, _, 0.7) for _ in w]

        # Load direction
        if not args.face_morph:
            attr_delta = nn.NdArray.from_numpy_array(
                np.load(args.attr_delta_path))
            attr_delta = F.reshape(attr_delta[0], (1, -1))
            w_plus = [w[0] + args.coeff * attr_delta, w[1]]
            w_minus = [w[0] - args.coeff * attr_delta, w[1]]
        else:
            w_plus = [w[0], w[0]]   # content
            w_minus = [w[1], w[1]]  # style

        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
        constant_bc = F.broadcast(constant_bc,
                                  (batch_size, ) + constant_bc.shape[1:])

        gen_plus = synthesis(w_plus, constant_bc, noise_seed=100,
                             mix_after=8)
        gen_minus = synthesis(w_minus, constant_bc, noise_seed=100,
                              mix_after=8)
        gen = synthesis(w, constant_bc, noise_seed=100, mix_after=8)

        image_plus = convert_images_to_uint8(gen_plus, drange=[-1, 1])
        image_minus = convert_images_to_uint8(gen_minus, drange=[-1, 1])
        image = convert_images_to_uint8(gen, drange=[-1, 1])

        for j in range(batch_size):
            filepath = os.path.join(args.save_image_path,
                                    f'image_{idx * batch_size + j}')
            imsave(f'{filepath}_o.png', image_plus[j], channel_first=True)
            imsave(f'{filepath}_y.png', image_minus[j], channel_first=True)
            imsave(f'{filepath}.png', image[j], channel_first=True)
            print(f"Generated. Saved {filepath}")
def vae(x, shape_z, test=False):
    """
    Compute the ELBO (evidence lower bound) loss.
    This sample is a Bernoulli-generator version.

    Args:
        x (~nnabla.Variable): N-D array
        shape_z (tuple of int): size of z
        test (bool): False for training (sample z with the
            reparameterization trick), True for test (use the mean).

    Returns:
        ~nnabla.Variable: ELBO loss
    """
    #############################################
    # Encoder of 2 fully connected layers       #
    #############################################

    # Normalize input
    xa = x / 256.
    batch_size = x.shape[0]

    # 2 fully connected layers, with Elu replacing the original Softplus.
    h = F.elu(PF.affine(xa, (500, ), name='fc1'))
    h = F.elu(PF.affine(h, (500, ), name='fc2'))

    # The outputs are the parameters of a Gaussian probability density.
    mu = PF.affine(h, shape_z, name='fc_mu')
    logvar = PF.affine(h, shape_z, name='fc_logvar')
    sigma = F.exp(0.5 * logvar)

    # The prior variable and the reparameterization trick
    if not test:
        # training with reparameterization trick
        epsilon = F.randn(mu=0, sigma=1, shape=(batch_size, ) + shape_z)
        z = mu + sigma * epsilon
    else:
        # test without randomness
        z = mu

    #############################################
    # Decoder of 2 fully connected layers       #
    #############################################

    # 2 fully connected layers, with Elu replacing the original Softplus.
    h = F.elu(PF.affine(z, (500, ), name='fc3'))
    h = F.elu(PF.affine(h, (500, ), name='fc4'))

    # The outputs are the parameters of Bernoulli probabilities for each
    # pixel.
    prob = PF.affine(h, (1, 28, 28), name='fc5')

    #############################################
    # Elbo components and loss objective        #
    #############################################

    # Binarized input
    xb = F.greater_equal_scalar(xa, 0.5)

    # E_q(z|x)[log(q(z|x))]
    # without the constant terms that cancel after summation of the loss
    logqz = 0.5 * F.sum(1.0 + logvar, axis=1)

    # E_q(z|x)[log(p(z))]
    # without the constant terms that cancel after summation of the loss
    logpz = 0.5 * F.sum(mu * mu + sigma * sigma, axis=1)

    # E_q(z|x)[log(p(x|z))]
    logpx = F.sum(F.sigmoid_cross_entropy(prob, xb), axis=(1, 2, 3))

    # VAE loss: the negative evidence lower bound
    loss = F.mean(logpx + logpz - logqz)

    return loss
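# Hedged usage sketch for vae() above: builds the ELBO loss graph for a
# batch of MNIST-like images. The batch size and the random input are
# illustrative assumptions.
import numpy as np
import nnabla as nn

def _demo_vae_loss():
    x = nn.Variable((16, 1, 28, 28))
    loss = vae(x, shape_z=(50, ), test=False)
    x.d = np.random.randint(0, 256, size=x.shape)
    loss.forward()
    print(loss.d)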
def train_loss(self, model, x_start, t, noise=None):
    """
    Calculate training loss for given data and model.

    Args:
        model (callable): A trainable model to predict noise in data
            conditioned by timestep. This function should perform like
            pred_noise = model(x_noisy, t). If self.model_var_type is one
            that requires a prediction for sigma, the model has to output
            it as well.
        x_start (nn.Variable): The (B, C, ...) tensor of x_0.
        t (nn.Variable): A 1-D tensor of timesteps.
        noise (nn.Variable or None): The noise added to x_start. If None,
            F.randn(shape=x_start.shape) will be used.

    Return:
        loss (dict of {string: nn.Variable}): A dict that holds the
        losses to train the `model`. You can access each loss by name:
            - `vlb`: Variational Lower Bound for learning sigma.
              Included only if self.model_var_type requires learning
              sigma.
            - `mse`: MSE between actual and predicted noise.
        Each entry is the (B, ) tensor of batched loss computed from the
        given inputs. Note that this function doesn't reduce the batch
        dim, in order to make it easy to trace the loss value at each
        timestep. Therefore, you should average the returned Variable
        over the batch dim to train the model.
    """
    B, C, H, W = x_start.shape
    assert t.shape == (B, )

    if noise is None:
        noise = F.randn(shape=x_start.shape)
    assert noise.shape == x_start.shape

    # Calculate x_t from x_start, t, and noise.
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    assert x_noisy.shape == x_start.shape

    # Predict noise.
    # According to the original DDPM, this is superior to reconstructing
    # x_0. If model_var_type requires learning sigma, the model must
    # output pred_sigma as well.
    pred = model(x_noisy, t)

    # Calculate losses
    ret = AttrDict()

    if is_learn_sigma(self.model_var_type):
        # Split pred into 2 variables along the channel axis.
        pred_noise, pred_sigma = chunk(pred, num_chunk=2, axis=1)
        assert pred_sigma.shape == x_start.shape, \
            f"Shape mismatch between pred_sigma {pred_sigma.shape} " \
            f"and x_start {x_start.shape}"

        # Variational lower bound for sigma.
        # Use a dummy function as the model, since we already have the
        # prediction from the model.
        var = F.concatenate(
            pred_noise.get_unlinked_variable(need_grad=False),
            pred_sigma, axis=1)
        ret.vlb = self._vlb_in_bits_per_dims(model=lambda x_t, t: var,
                                             x_start=x_start,
                                             x_t=x_noisy,
                                             t=t)
    else:
        pred_noise = pred
        assert pred_noise.shape == x_start.shape, \
            f"Shape mismatch between pred_noise {pred_noise.shape} " \
            f"and x_start {x_start.shape}"

    ret.mse = mean_along_except_batch(F.squared_error(noise, pred_noise))

    # Shape check for all losses.
    for name, loss in ret.items():
        assert loss.shape == (B, ), \
            f"A Variable for loss `{name}` has a wrong shape " \
            f"({loss.shape} != {(B, )})"

    return ret
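# Hedged sketch of consuming train_loss above: reduce each per-sample loss
# over the batch dim before backward, as the docstring advises. `diffusion`,
# `model`, `x_start`, and `t` are assumed to exist.
import nnabla.functions as F

def _demo_train_loss(diffusion, model, x_start, t):
    losses = diffusion.train_loss(model, x_start, t)
    loss = F.mean(losses.mse)  # (B,) -> scalar
    if 'vlb' in losses:
        loss = loss + F.mean(losses.vlb)
    return loss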
def build_static_graph(self):
    real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                  self.img_size))
    noises = [F.randn(shape=(self.batch_size, self.config['latent_dim']))
              for _ in range(2)]

    if self.few_shot_config['common']['type'] == 'cdc':
        # Cross-domain correspondence: replace the plain Gaussian noise
        # with the anchored noise generator.
        NT_class = NoiseTop(n_train=self.train_loader.size,
                            latent_dim=self.config['latent_dim'],
                            batch_size=self.batch_size)
        noises = NT_class()
        self.PD_switch_var = NT_class.PD_switch_var

    if self.config['regularize_gen']:
        fake_img, dlatents = self.generator(self.batch_size, noises,
                                            return_latent=True)
    else:
        fake_img = self.generator(self.batch_size, noises)
    fake_img_test = self.generator_ema(self.batch_size, noises)

    if self.few_shot_config['common']['type'] != 'cdc':
        fake_disc_out = self.discriminator(fake_img)
        real_disc_out = self.discriminator(real_img)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

    gen_loss = 0
    if self.few_shot_config['common']['type'] == 'cdc':
        fake_img_s = self.generator_s(self.batch_size, noises)
        cdc_loss = CrossDomainCorrespondence(
            fake_img, fake_img_s,
            _choice_num=self.few_shot_config['cdc']['feature_num'],
            _layer_fix_switch=self.few_shot_config['cdc']['layer_fix'])
        gen_loss += self.few_shot_config['cdc']['lambda'] * cdc_loss
        # --- PatchDiscriminator ---
        fake_disc_out, fake_feature_var = self.discriminator(
            fake_img, patch_switch=True, index=0)
        real_disc_out, real_feature_var = self.discriminator(
            real_img, patch_switch=True, index=0)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)
        disc_loss_patch = disc_logistic_loss(fake_feature_var,
                                             real_feature_var)
        disc_loss += self.PD_switch_var * disc_loss_patch
    gen_loss += gen_nonsaturating_loss(fake_disc_out)

    var_name_list = ['real_img', 'noises', 'fake_img', 'gen_loss',
                     'disc_loss', 'fake_disc_out', 'real_disc_out',
                     'fake_img_test']
    var_list = [real_img, noises, fake_img, gen_loss, disc_loss,
                fake_disc_out, real_disc_out, fake_img_test]

    if self.config['regularize_gen']:
        dlatents.need_grad = True
        mean_path_length = nn.Variable()
        pl_reg, path_mean, _ = gen_path_regularize(
            fake_img=fake_img,
            latents=dlatents,
            mean_path_length=mean_path_length)
        path_mean_update = F.assign(mean_path_length, path_mean)
        path_mean_update.name = 'path_mean_update'
        pl_reg += 0 * path_mean_update
        gen_loss_reg = gen_loss + pl_reg
        var_name_list.append('gen_loss_reg')
        var_list.append(gen_loss_reg)

    if self.config['regularize_disc']:
        real_img.need_grad = True
        real_disc_out = self.discriminator(real_img)
        disc_loss_reg = disc_loss + self.config['r1_coeff'] * 0.5 * \
            disc_r1_loss(real_disc_out, real_img) * \
            self.config['disc_reg_step']
        real_img.need_grad = False
        var_name_list.append('disc_loss_reg')
        var_list.append(disc_loss_reg)

    Parameters = namedtuple('Parameters', var_name_list)
    self.parameters = Parameters(*var_list)
def generate_attribute_direction(args, attribute_prediction_model):
    if not os.path.isfile(os.path.join(args.weights_path, 'gen_params.h5')):
        os.makedirs(args.weights_path, exist_ok=True)
        print("Downloading the pretrained tf-converted weights. "
              "Please wait...")
        url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
        from nnabla.utils.data_source_loader import download
        download(url, os.path.join(args.weights_path, 'gen_params.h5'),
                 False)

    nn.load_parameters(os.path.join(args.weights_path, 'gen_params.h5'))
    print('Loaded pretrained weights from tensorflow!')

    nn.load_parameters(args.classifier_weight_path)
    print(f'Loaded {args.classifier_weight_path}')

    batches = [args.batch_size
               for _ in range(args.num_images // args.batch_size)]
    if args.num_images % args.batch_size != 0:
        batches.append(
            args.num_images
            - (args.num_images // args.batch_size) * args.batch_size)

    w_plus, w_minus = 0.0, 0.0
    w_plus_count, w_minus_count = 0.0, 0.0

    pbar = trange(len(batches))
    for i in pbar:
        batch_size = batches[i]
        z = [F.randn(shape=(batch_size, 512)).data]
        z = [z[0], z[0]]

        # Normalize each latent to (approximately) unit RMS.
        for j in range(len(z)):
            z[j] = F.div2(
                z[j],
                F.pow_scalar(F.add_scalar(
                    F.mean(z[j] ** 2., axis=1, keepdims=True), 1e-8),
                    0.5, inplace=True))

        # get latent code
        w = [mapping_network(z[0], outmaps=512, num_layers=8)]
        w += [mapping_network(z[1], outmaps=512, num_layers=8)]

        # truncation trick
        dlatent_avg = nn.parameter.get_parameter_or_create(
            name="dlatent_avg", shape=(1, 512))
        w = [lerp(dlatent_avg, _, 0.7) for _ in w]

        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
        constant_bc = F.broadcast(constant_bc,
                                  (batch_size, ) + constant_bc.shape[1:])

        gen = synthesis(w, constant_bc, noise_seed=100, mix_after=7)

        classifier_score = F.softmax(attribute_prediction_model(gen, True))
        confidence, class_pred = F.max(classifier_score, axis=1,
                                       with_index=True, keepdims=True)

        # Accumulate latents of confidently classified samples per class.
        w_plus += np.sum(w[0].data * (class_pred.data == 0) *
                         (confidence.data > 0.65), axis=0, keepdims=True)
        w_minus += np.sum(w[0].data * (class_pred.data == 1) *
                          (confidence.data > 0.65), axis=0, keepdims=True)
        w_plus_count += np.sum((class_pred.data == 0) *
                               (confidence.data > 0.65))
        w_minus_count += np.sum((class_pred.data == 1) *
                                (confidence.data > 0.65))
        pbar.set_description(f'{w_plus_count} {w_minus_count}')

    # save attribute direction
    attribute_variation_direction = (w_plus / w_plus_count) - \
        (w_minus / w_minus_count)
    print(w_plus_count, w_minus_count)
    np.save(f'{args.classifier_weight_path.split("/")[0]}/direction.npy',
            attribute_variation_direction)
def train():
    rng = np.random.RandomState(803)
    conf = get_config()
    comm = init_nnabla(conf)

    # create data iterator
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           save_file=comm.rank == 0)
        n_class = conf.cityscapes.n_label_ids
        use_inst = True
        data_iter = create_cityscapes_iterator(conf.batch_size,
                                               data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng,
                                               flip=conf.use_flip)
    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k,
                                        save_file=comm.rank == 0)
        n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False
        load_shape = tuple(
            x + 30
            for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size,
                                           data_list,
                                           comm=comm,
                                           load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng,
                                           flip=conf.use_flip)
    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape)
    obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)

    if use_inst:
        ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)
        obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    # generator
    generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape)
    z = F.randn(shape=(conf.batch_size, conf.z_dim))
    fake = generator(z, mask)

    # unlinking
    ul_mask, ul_fake = get_unlinked_all(mask, fake)

    # discriminator
    discriminator = PatchGAN(n_scales=conf.d_n_scales)
    d_input_real = F.concatenate(real, ul_mask, axis=1)
    d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1)
    d_real_out, d_real_feats = discriminator(d_input_real)
    d_fake_out, d_fake_feats = discriminator(d_input_fake)
    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(
        d_real_out, d_real_feats, d_fake_out, d_fake_feats,
        use_fm=conf.use_fm, fm_lambda=conf.lambda_fm,
        gan_loss_type=conf.gan_loss_type)

    def _rescale(x):
        return rescale_values(x, input_min=-1, input_max=1,
                              output_min=0, output_max=255)

    g_vgg = vgg16_perceptual_loss(_rescale(ul_fake),
                                  _rescale(real)) * conf.lambda_vgg

    set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg)

    # loss
    g_loss = g_gan + g_feat + g_vgg
    d_loss = (d_real + d_fake) / 2

    # load params
    if conf.load_params is not None:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    # Setup Solvers
    g_solver = S.Adam(beta1=0.)
    g_solver.set_parameters(get_params_startswith("spade_generator"))
    d_solver = S.Adam(beta1=0.)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    g_lrs = LinearDecayScheduler(start_lr=conf.g_lr, end_lr=0.,
                                 start_iter=100, end_iter=200)
    d_lrs = LinearDecayScheduler(start_lr=conf.d_lr, end_lr=0.,
                                 start_iter=100, end_iter=200)

    ipe = get_iteration_per_epoch(data_iter._size, conf.batch_size,
                                  round="ceil")
    if not conf.show_interval:
        conf.show_interval = ipe
    if not conf.save_interval:
        conf.save_interval = ipe
    if not conf.niter:
        conf.niter = 200 * ipe

    # Setup Reporter
    losses = {"g_gan": g_gan, "g_feat": g_feat, "g_vgg": g_vgg,
              "d_real": d_real, "d_fake": d_fake}
    reporter = Reporter(comm, losses, conf.save_path,
                        nimage_per_epoch=min(conf.batch_size, 5),
                        show_interval=10)
    progress_iterator = trange(conf.niter, disable=comm.rank > 0)
    reporter.start(progress_iterator)

    colorizer = Colorize(n_class)

    # output all config and dump to file
    if comm.rank == 0:
        conf.dump_to_stdout()
        write_yaml(os.path.join(conf.save_path, "config.yaml"), conf)

    epoch = 0
    for itr in progress_iterator:
        if itr % ipe == 0:
            g_lr = g_lrs(epoch)
            d_lr = d_lrs(epoch)
            g_solver.set_learning_rate(g_lr)
            d_solver.set_learning_rate(d_lr)
            if comm.rank == 0:
                print("\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format(
                    epoch, g_lr, d_lr))
            epoch += 1

        if conf.dataset == "cityscapes":
            im, ist, obj = data_iter.next()
            ist_mask.d = ist
        elif conf.dataset == "ade20k":
            im, obj = data_iter.next()
        else:
            raise NotImplementedError()

        real.d = im
        obj_mask.d = obj

        # create fake
        fake.forward()

        # update discriminator
        d_solver.zero_grad()
        d_loss.forward()
        d_loss.backward(clear_buffer=True)
        comm.all_reduced_solver_update(d_solver)

        # update generator
        ul_fake.grad.zero()
        g_solver.zero_grad()
        g_loss.forward()
        g_loss.backward(clear_buffer=True)

        # backward generator
        fake.backward(grad=None, clear_buffer=True)
        comm.all_reduced_solver_update(g_solver)

        # report iteration progress
        reporter()

        # report epoch progress
        show_epoch = itr // conf.show_interval
        if (itr % conf.show_interval) == 0:
            show_images = {
                "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)),
                "ObjectMask": colorizer(obj).astype(np.uint8),
                "GeneratedImage": fake.data.get_data("r").transpose(
                    (0, 2, 3, 1))}
            reporter.step(show_epoch, show_images)

        if (itr % conf.save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(conf.save_path,
                             'param_{:03d}.h5'.format(show_epoch)))

    if comm.rank == 0:
        nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
def project(self, args):
    nn.set_auto_forward(True)

    # Input Image Variable
    image = Image.open(args.img_path).convert("RGB").resize(
        (256, 256), resample=Image.BILINEAR)
    image = np.array(image) / 255.0
    image = np.transpose(image.astype(np.float32), (2, 0, 1))
    image = np.expand_dims(image, 0)
    image = (image - 0.5) / 0.5
    image = nn.Variable.from_numpy_array(image)

    # Get latent space mean and std. dev.
    z = F.randn(shape=(self.n_latent, self.latent_dim)).data
    w = mapping_network(z)
    latent_mean = F.mean(w, axis=0, keepdims=True)
    latent_std = F.pow_scalar(F.mean(F.pow_scalar(w - latent_mean, 2)), 0.5)

    # Get noise for the B network.
    noises = [F.randn(shape=(1, 1, 4, 4)).data]
    for res in self.generator.resolutions[1:]:
        for _ in range(2):
            shape = (1, 1, res, res)
            noises.append(F.randn(shape=shape).data)

    # Prepare parameters to be optimized.
    latent_in = nn.Variable.from_numpy_array(
        latent_mean.data).apply(need_grad=True)
    noises = [nn.Variable.from_numpy_array(n.data).apply(need_grad=True)
              for n in noises]
    constant_bc = nn.parameter.get_parameter_or_create(
        name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
    constant_bc = F.broadcast(constant_bc, (1, ) + constant_bc.shape[1:])

    pbar = tqdm(range(self.num_iters))
    for i in pbar:
        # Anneal the learning rate and the additive latent noise.
        t = i / self.num_iters
        self.set_lr(t)
        noise_strength = latent_std * 0.05 * max(0, 1 - t / 0.75) ** 2
        latent_n = self.latent_noise(latent_in, noise_strength)

        gen_out = self.generator.synthesis([latent_n, latent_n],
                                           constant_bc, noises_in=noises)
        # Downsample the generated image to 256 x 256 by average pooling.
        N, C, H, W = gen_out.shape
        factor = H // 256
        gen_out = F.reshape(
            gen_out, (N, C, H // factor, factor, W // factor, factor))
        gen_out = F.mean(gen_out, axis=(3, 5))

        p_loss = F.sum(self.lpips_distance(image, gen_out))
        n_loss = self.regularize_noise(noises)
        mse_loss = F.mean((gen_out - image) ** 2)
        loss = p_loss + self.n_c * n_loss + self.mse_c * mse_loss

        param_dict = {'latent': latent_in}
        for k in range(len(noises)):
            param_dict[f'noise_{k}'] = noises[k]

        self.solver.zero_grad()
        self.solver.set_parameters(param_dict, reset=False,
                                   retain_state=True)
        loss.backward()
        self.solver.update()
        noises = self.normalize_noises(noises)

        pbar.set_description(f'Loss: {loss.d} P Loss: {p_loss.d}')

    save_generations(image, 'original.png')

    gen_out = self.generator.synthesis([latent_n, latent_n], constant_bc,
                                       noises_in=noises)
    N, C, H, W = gen_out.shape
    factor = H // 256
    gen_out = F.reshape(gen_out,
                        (N, C, H // factor, factor, W // factor, factor),
                        inplace=True)
    gen_out = F.mean(gen_out, axis=(3, 5))
    save_generations(gen_out, 'projected.png')
    nn.save_parameters('projection_params.h5', param_dict)
def _transition(self, epoch_per_resolution):
    batch_size = self.di.batch_size
    resolution = self.gen.resolution_list[-1]
    phase = "{}to{}".format(self.gen.resolution_list[-2],
                            self.gen.resolution_list[-1])
    logger.info("phase : {}".format(phase))

    kernel_size = self.resolution_list[-1] // resolution
    kernel = (kernel_size, kernel_size)

    total_itr = (self.di.size // batch_size + 1) * epoch_per_resolution
    global_itr = 1.
    alpha = global_itr / total_itr

    for epoch in range(epoch_per_resolution):
        logger.info("epoch : {}".format(epoch + 1))
        itr = 0
        current_epoch = self.di.epoch
        while self.di.epoch == current_epoch:
            img, _ = self.di.next()
            x = nn.Variable.from_numpy_array(img)

            # Train the discriminator.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha, test=True)
            y = y.unlinked()
            y.need_grad = False

            x_r = F.average_pooling(x, kernel=kernel)
            p_real = self.dis.transition(x_r, alpha)
            p_fake = self.dis.transition(y, alpha)

            loss_dis = F.mean(
                F.pow_scalar((p_real - 1), 2.)
                + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)

            if itr % self.n_critic + 1 == self.n_critic:
                with nn.parameter_scope("discriminator"):
                    self.solver_dis.set_parameters(nn.get_parameters(),
                                                   reset=False,
                                                   retain_state=True)
                self.solver_dis.zero_grad()
                loss_dis.backward(clear_buffer=True)
                self.solver_dis.update()

            # Train the generator.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha, test=False)
            p_fake = self.dis.transition(y, alpha)
            loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2))

            with nn.parameter_scope("generator"):
                self.solver_gen.set_parameters(nn.get_parameters(),
                                               reset=False,
                                               retain_state=True)
            self.solver_gen.zero_grad()
            loss_gen.backward(clear_buffer=True)
            self.solver_gen.update()

            itr += 1
            global_itr += 1.
            alpha = global_itr / total_itr

        if epoch % self.save_image_interval + 1 == self.save_image_interval:
            z = nn.Variable.from_numpy_array(self.z_test)
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha)
            img_name = "phase_{}_epoch_{}".format(phase, epoch + 1)
            self.monitor_image_tile.add(img_name,
                                        F.unpooling(y, kernel=kernel))
def _train(self, epoch_per_resolution, each_save=False):
    batch_size = self.di.batch_size
    resolution = self.gen.resolution_list[-1]
    logger.info("phase : {}".format(resolution))

    kernel_size = self.resolution_list[-1] // resolution
    kernel = (kernel_size, kernel_size)

    img_name = "original_phase_{}".format(resolution)
    img, _ = self.di.next()
    self.monitor_image_tile.add(img_name, img)

    for epoch in range(epoch_per_resolution):
        logger.info("epoch : {}".format(epoch + 1))
        itr = 0
        current_epoch = self.di.epoch
        while self.di.epoch == current_epoch:
            img, _ = self.di.next()
            x = nn.Variable.from_numpy_array(img)

            # Train the discriminator.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=True)
            y = y.unlinked()
            y.need_grad = False

            x_r = F.average_pooling(x, kernel=kernel)
            p_real = self.dis(x_r)
            p_fake = self.dis(y)
            p_real.persistent, p_fake.persistent = True, True
            loss_dis = F.mean(
                F.pow_scalar((p_real - 1), 2.)
                + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)
            loss_dis.persistent = True

            if itr % self.n_critic + 1 == self.n_critic:
                with nn.parameter_scope("discriminator"):
                    self.solver_dis.set_parameters(nn.get_parameters(),
                                                   reset=False,
                                                   retain_state=True)
                self.solver_dis.zero_grad()
                loss_dis.backward(clear_buffer=True)
                self.solver_dis.update()

            # Train the generator.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=False)
            p_fake = self.dis(y)
            p_fake.persistent = True
            loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2.))
            loss_gen.persistent = True

            with nn.parameter_scope("generator"):
                self.solver_gen.set_parameters(nn.get_parameters(),
                                               reset=False,
                                               retain_state=True)
            self.solver_gen.zero_grad()
            loss_gen.backward(clear_buffer=True)
            self.solver_gen.update()

            # Monitor
            self.monitor_p_real.add(self.global_itr, p_real.d.copy().mean())
            self.monitor_p_fake.add(self.global_itr, p_fake.d.copy().mean())
            self.monitor_loss_dis.add(self.global_itr, loss_dis.d.copy())
            self.monitor_loss_gen.add(self.global_itr, loss_gen.d.copy())

            itr += 1
            self.global_itr += 1

        if epoch % self.save_image_interval + 1 == self.save_image_interval:
            z = nn.Variable.from_numpy_array(self.z_test)
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=True)
            img_name = "phase_{}_epoch_{}".format(resolution, epoch + 1)
            self.monitor_image_tile.add(img_name,
                                        F.unpooling(y, kernel=kernel))

        if each_save:
            self.gen.save_parameters(
                self.monitor_path,
                "Gen_phase_{}_epoch_{}".format(self.resolution_list[-1],
                                               epoch + 1))
            self.dis.save_parameters(
                self.monitor_path,
                "Dis_phase_{}_epoch_{}".format(self.resolution_list[-1],
                                               epoch + 1))
def infer(self, mels, sigma=0.9):
    r"""Returns the generated audio.

    Args:
        mels (nn.Variable): Inputs containing mel-spectrograms of shape
            (B, n_mels, Ty).
        sigma (float, optional): Sigma used to infer audio.
            Defaults to 0.9.

    Returns:
        nn.Variable: A synthetic audio.
    """
    hp = self.hparams
    with nn.parameter_scope('', self.parameter_scope):
        # Upsample spectrogram to the size of the audio.
        with nn.parameter_scope('upsample'):
            with nn.parameter_scope('deconv'):
                mels = PF.deconvolution(mels, hp.n_mels, kernel=(1024, ),
                                        stride=(256, ))
            # cut out conv artifacts
            mels = mels[..., :-(1024 - 256)]  # kernel - stride

            # transforming to correct shape
            mels = F.reshape(mels,
                             mels.shape[:2] + (-1, hp.n_samples_per_group))
            mels = F.transpose(mels, (0, 2, 1, 3))
            mels = F.reshape(mels, mels.shape[:2] + (-1, ))
            # (B, n_mels * n_groups, L/n_groups)
            mels = F.transpose(mels, (0, 2, 1))

        wave = F.randn(shape=(mels.shape[0], self.n_remaining_channels,
                              mels.shape[2])) * sigma

        # Run the flows in reverse order to invert the model.
        for k in reversed(range(hp.n_flows)):
            n_half = wave.shape[1] // 2
            audio_0 = wave[:, :n_half, :]
            audio_1 = wave[:, n_half:, :]

            with nn.parameter_scope(f'wn_{k}'):
                output = getattr(self, f'WN_{k}')(audio_0, mels)
                s = output[:, n_half:, :]
                b = output[:, :n_half, :]
                audio_1 = (audio_1 - b) / F.exp(s)
                wave = F.concatenate(audio_0, audio_1, axis=1)

            wave = invertible_conv(wave, reverse=True, rng=self.rng,
                                   scope=f'inv_{k}')

            if k % hp.n_early_every == 0 and k > 0:
                z = F.randn(shape=(mels.shape[0], hp.n_early_size,
                                   mels.shape[2]))
                wave = F.concatenate(sigma * z, wave, axis=1)

        wave = F.transpose(wave, (0, 2, 1))
        wave = F.reshape(wave, (wave.shape[0], -1))

    return wave
def generate(args):
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])
    one = nn.Variable.from_numpy_array(np.ones((1, 1, 1, 1)) * 0.5)

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains (domain A)
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains (domain B)
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage(
        "Fake Image B to A {} Valid".format(suffix), monitor, interval=1)
    monitor_image_b = MonitorImage(
        "Fake Image A to B {} Valid".format(suffix), monitor, interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    # Generate all
    # generate (A -> B)
    if args.example_guided:
        x_real_b.d = di_b.next()[0]
    for i in range(di_a.size):
        x_real_a.d = di_a.next()[0]
        images = []
        images.append(x_real_a.d.copy())
        for _ in range(args.num_repeats):
            x_fake_b.forward(clear_buffer=True)
            images.append(x_fake_b.d.copy())
        monitor_image_b.add(i, np.concatenate(images, axis=3))

    # generate (B -> A)
    if args.example_guided:
        x_real_a.d = di_a.next()[0]
    for i in range(di_b.size):
        x_real_b.d = di_b.next()[0]
        images = []
        images.append(x_real_b.d.copy())
        for _ in range(args.num_repeats):
            x_fake_a.forward(clear_buffer=True)
            images.append(x_fake_a.d.copy())
        monitor_image_a.add(i, np.concatenate(images, axis=3))
def synthesis(self, w_mixed, constant_bc, seed=-1, noises_in=None):
    batch_size = w_mixed.shape[0]

    if noises_in is None:
        noise = F.randn(shape=(batch_size, 1, 4, 4), seed=seed)
    else:
        noise = noises_in[0]

    w = F.reshape(F.slice(w_mixed,
                          start=(0, 0, 0),
                          stop=(w_mixed.shape[0], 1, w_mixed.shape[2]),
                          step=(1, 1, 1)),
                  (w_mixed.shape[0], w_mixed.shape[2]),
                  inplace=False)
    h = styled_conv_block(constant_bc, w, noise,
                          res=self.resolutions[0],
                          outmaps=self.feature_map_dim,
                          namescope="Conv")
    torgb = styled_conv_block(h, w, noise=None,
                              res=self.resolutions[0],
                              outmaps=3, inmaps=self.feature_map_dim,
                              kernel_size=1, pad_size=0,
                              demodulate=False, namescope="ToRGB",
                              act=F.identity)

    # initial feature maps
    outmaps = self.feature_map_dim
    inmaps = self.feature_map_dim
    downsize_index = 4 if self.resolutions[-1] in [512, 1024] else 3

    # resolution 8 x 8 - 1024 x 1024
    for i in range(1, len(self.resolutions)):
        i1 = (2 + i) * 2 - 5
        i2 = (2 + i) * 2 - 4
        i3 = (2 + i) * 2 - 3

        w_ = F.reshape(F.slice(w_mixed,
                               start=(0, i1, 0),
                               stop=(w_mixed.shape[0], i1 + 1,
                                     w_mixed.shape[2]),
                               step=(1, 1, 1)),
                       w.shape, inplace=False)
        if i > downsize_index:
            outmaps = outmaps // 2
        curr_shape = (batch_size, 1, self.resolutions[i],
                      self.resolutions[i])
        if noises_in is None:
            noise = F.randn(shape=curr_shape, seed=seed)
        else:
            noise = noises_in[2 * i - 1]
        h = styled_conv_block(h, w_, noise,
                              res=self.resolutions[i],
                              outmaps=outmaps, inmaps=inmaps,
                              kernel_size=3, up=True,
                              namescope="Conv0_up")

        w_ = F.reshape(F.slice(w_mixed,
                               start=(0, i2, 0),
                               stop=(w_mixed.shape[0], i2 + 1,
                                     w_mixed.shape[2]),
                               step=(1, 1, 1)),
                       w.shape, inplace=False)
        if i > downsize_index:
            inmaps = inmaps // 2
        if noises_in is None:
            noise = F.randn(shape=curr_shape, seed=seed)
        else:
            noise = noises_in[2 * i]
        h = styled_conv_block(h, w_, noise,
                              res=self.resolutions[i],
                              outmaps=outmaps, inmaps=inmaps,
                              kernel_size=3, pad_size=1,
                              namescope="Conv1")

        w_ = F.reshape(F.slice(w_mixed,
                               start=(0, i3, 0),
                               stop=(w_mixed.shape[0], i3 + 1,
                                     w_mixed.shape[2]),
                               step=(1, 1, 1)),
                       w.shape, inplace=False)
        curr_torgb = styled_conv_block(h, w_, noise=None,
                                       res=self.resolutions[i],
                                       outmaps=3, inmaps=inmaps,
                                       kernel_size=1, pad_size=0,
                                       demodulate=False,
                                       namescope="ToRGB",
                                       act=F.identity)
        # Skip connection: upsample the previous RGB output and add.
        torgb = F.add2(curr_torgb, upsample_2d(torgb, k=[1, 3, 3, 1]))

    return torgb
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module,
                                type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generation over domains and reconstruction of content/style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps,
                                      name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generation over domains and reconstruction of content/style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps,
                                      name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domains A and B)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(
        x_style_rec_a, z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(
        x_content_rec_b, x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(
        x_style_rec_b, z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(
        x_content_rec_a, x_content_a).apply(persistent=True)

    # adversarial
    def f(x, y):
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)

    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A", monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor, interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor, interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B", monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor, interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor, interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train", monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train", monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b)]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices

        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    # Final monitoring and save after the training loop.
    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))
def sample_loop(self, model, shape, sampler, noise=None, dump_interval=-1,
                progress=False, without_auto_forward=False):
    """
    Iteratively sample data from the model from t=T to t=0.
    T is specified as the length of betas given to __init__().

    Args:
        model (callable): A callable that takes x_t and t and predicts
            noise (and sigma-related parameters).
        shape (list-like object): A data shape.
        sampler (callable): A function that samples x_{t-1} given x_{t}
            and t. Typically, self.p_sample or self.ddim_sample.
        noise (np.ndarray, optional): The initial noise x_T. If None,
            a standard Gaussian is drawn.
        dump_interval (int): If > 0, all intermediate results at every
            `dump_interval` step will be returned as a list.
            e.g. if dump_interval = 10, the predicted results at
            {10, 20, 30, ...} will be returned.
        progress (bool): If True, tqdm will be used to show the sampling
            progress.

    Returns:
        - x_0 (nn.Variable): the final sampled result of x_0
        - samples (a list of nn.Variable): the sampled results at every
          `dump_interval`
        - pred_x_starts (a list of nn.Variable): the predicted x_0 from
          each x_t at every `dump_interval`
    """
    T = self.num_timesteps
    indices = list(range(T))[::-1]

    samples = []
    pred_x_starts = []

    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)

    if without_auto_forward:
        if noise is None:
            noise = np.random.randn(*shape)
        else:
            assert isinstance(noise, np.ndarray)
            assert noise.shape == shape
        x_t = nn.Variable.from_numpy_array(noise)
        t = nn.Variable.from_numpy_array([T - 1 for _ in range(shape[0])])

        # Build a static graph once and update (x_t, t) in place each step.
        y, pred_x_start = sampler(model, x_t, t)
        up_x_t = F.assign(x_t, y)
        up_t = F.assign(t, t - 1)
        update = F.sink(up_x_t, up_t)

        cnt = 0
        for step in indices:
            y.forward(clear_buffer=True)
            update.forward(clear_buffer=True)
            cnt += 1
            if dump_interval > 0 and cnt % dump_interval == 0:
                samples.append((step, y.d.copy()))
                pred_x_starts.append((step, pred_x_start.d.copy()))
    else:
        with nn.auto_forward():
            if noise is None:
                x_t = F.randn(shape=shape)
            else:
                assert isinstance(noise, np.ndarray)
                assert noise.shape == shape
                x_t = nn.Variable.from_numpy_array(noise)
            cnt = 0
            for step in indices:
                t = F.constant(step, shape=(shape[0], ))
                x_t, pred_x_start = sampler(model, x_t, t,
                                            no_noise=step == 0)
                cnt += 1
                if dump_interval > 0 and cnt % dump_interval == 0:
                    samples.append((step, x_t.d.copy()))
                    pred_x_starts.append((step, pred_x_start.d.copy()))

    assert x_t.shape == shape
    return x_t.d.copy(), samples, pred_x_starts
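# Hedged usage sketch for sample_loop above; `diffusion` and `model` are
# assumed to exist, and `model` must map (x_t, t) to the predicted noise
# with a matching shape, as the docstring describes.
def _demo_sample_loop(diffusion, model):
    shape = (4, 3, 32, 32)
    x0, samples, pred_x_starts = diffusion.sample_loop(
        model, shape, sampler=diffusion.p_sample,
        dump_interval=100, progress=True)
    return x0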