Example #1
    def style_mixing(self, test_config, args):

        from nnabla.utils.image_utils import imsave, imresize

        print('Testing style mixing of generation...')

        z1 = F.randn(shape=(args.batch_size_A, test_config['latent_dim']),
                     seed=args.seed_1[0]).data
        z2 = F.randn(shape=(args.batch_size_B, test_config['latent_dim']),
                     seed=args.seed_2[0]).data

        nn.set_auto_forward(True)

        mix_image_stacks = []
        for i in range(args.batch_size_A):
            image_column = []
            for j in range(args.batch_size_B):
                style_noises = [
                    F.reshape(z1[i], (1, 512)),
                    F.reshape(z2[j], (1, 512))
                ]
                rgb_output = self.generator(
                    1,
                    style_noises,
                    test_config['truncation_psi'],
                    mixing_layer_index=test_config['mix_after'])
                image = save_generations(rgb_output, None, return_images=True)
                image_column.append(image[0])
            image_column = np.concatenate(image_column, axis=1)
            mix_image_stacks.append(image_column)
        mix_image_stacks = np.concatenate(mix_image_stacks, axis=2)

        style_noises = [z1, z1]
        rgb_output = self.generator(args.batch_size_A, style_noises,
                                    test_config['truncation_psi'])
        image_A = save_generations(rgb_output, None, return_images=True)
        image_A = np.concatenate(image_A, axis=2)

        style_noises = [z2, z2]
        rgb_output = self.generator(args.batch_size_B, style_noises,
                                    test_config['truncation_psi'])
        image_B = save_generations(rgb_output, None, return_images=True)
        image_B = np.concatenate(image_B, axis=1)

        top_image = 255 * np.ones(rgb_output[0].shape).astype(np.uint8)

        top_image = np.concatenate((top_image, image_A), axis=2)
        grid_image = np.concatenate((image_B, mix_image_stacks), axis=2)
        grid_image = np.concatenate((top_image, grid_image), axis=1)

        filename = os.path.join(self.results_dir, 'style_mix.png')
        imsave(filename,
               imresize(grid_image, (1024, 1024), channel_first=True),
               channel_first=True)
        print(f'Output saved as {filename}')
Example #2
def ce_loss_with_uncertainty(ctx, pred, y_l, log_var):
    r = F.randn(0., 1., log_var.shape)
    r = F.pow_scalar(F.exp(log_var), 0.5) * r
    h = pred + r
    with nn.context_scope(ctx):
        loss_ce = F.mean(F.softmax_cross_entropy(h, y_l))
    return loss_ce
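
A minimal sketch of driving this loss on dummy data (not from the original source; the shapes, class count, and CPU context below are assumptions):

import numpy as np
import nnabla as nn
import nnabla.functions as F
from nnabla.ext_utils import get_extension_context

# Hypothetical batch of 8 samples with 10 classes.
pred = nn.Variable.from_numpy_array(np.random.randn(8, 10).astype(np.float32))
log_var = nn.Variable.from_numpy_array(np.zeros((8, 10), dtype=np.float32))
y_l = nn.Variable.from_numpy_array(np.random.randint(0, 10, size=(8, 1)))

ctx = get_extension_context('cpu')
loss = ce_loss_with_uncertainty(ctx, pred, y_l, log_var)
loss.forward()
print(loss.d)  # mean cross-entropy over the perturbed logits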
Example #3
def test_randn_forward_backward(seed, ctx, func_name, mu, sigma, shape):
    with nn.context_scope(ctx):
        o = F.randn(mu, sigma, shape, seed=seed)
    assert o.shape == tuple(shape)
    assert o.parent.name == func_name
    o.forward()
    if o.size >= 10000:
        est_mu = o.d.mean()
        est_sigma = o.d.std()
        assert np.isclose(est_mu, mu, atol=sigma)
        assert np.isclose(est_sigma, sigma, atol=sigma)
    else:
        data = []
        for i in range(10000):
            o.forward()
            data += [o.d.copy()]
        est_mu = np.mean(np.array(data))
        est_sigma = np.std(np.array(data))
        assert np.isclose(est_mu, mu, atol=sigma)
        assert np.isclose(est_sigma, sigma, atol=sigma)

    # Checking recomputation
    func_args = [mu, sigma, shape, seed]
    recomputation_test(rng=None,
                       func=F.randn,
                       vinputs=[],
                       func_args=func_args,
                       func_kwargs={},
                       ctx=ctx)
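
The statistical check in this test boils down to the following standalone sketch (the shape and seed here are assumptions, chosen large enough for a one-shot estimate):

import nnabla as nn
import nnabla.functions as F

nn.set_auto_forward(True)
o = F.randn(mu=2.0, sigma=3.0, shape=(1000, 100), seed=313)
print(o.d.mean(), o.d.std())  # should land near mu=2.0 and sigma=3.0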
Example #4
def sample_noise(inpt_size, out_size):
    _f = lambda x: F.sign(x) * F.pow_scalar(F.abs(x), 0.5)
    noise = _f(F.randn(shape=(inpt_size + out_size, )))
    eps_w = F.batch_matmul(F.reshape(noise[:inpt_size], (1, -1)),
                           F.reshape(noise[inpt_size:], (1, -1)), True)
    eps_b = noise[inpt_size:]
    return eps_w, eps_b
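
This is the factorized Gaussian noise used by NoisyNet-style layers: f(x) = sign(x) * sqrt(|x|) is applied to one noise vector, and the weight noise is the outer product of its two halves. A rough usage sketch, assuming the function above is in scope and with hypothetical layer sizes:

eps_w, eps_b = sample_noise(inpt_size=4, out_size=2)
print(eps_w.shape, eps_b.shape)  # (4, 2) and (2,); forward() materializes values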
Example #5
    def __init__(self, bs, **kwargs):
        noise = F.randn(mu=0, sigma=kwargs['sigma_affine'], shape=(bs, 2, 3))
        self.theta = noise + \
            nn.Variable.from_numpy_array(
                np.array([[[1., 0., 0.], [0., 1., 0.]]]))
        self.bs = bs

        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid(
                (kwargs['points_tps'], kwargs['points_tps']))
            self.control_points = F.reshape(
                self.control_points, (1,) + self.control_points.shape)
            self.control_params = F.randn(
                mu=0, sigma=kwargs['sigma_tps'], shape=(bs, 1, kwargs['points_tps'] ** 2))
        else:
            self.tps = False
Example #6
    def graph(x1):
        x1 = F.identity(x1).apply(recompute=True)
        x2 = F.randn(shape=x1.shape, seed=123).apply(recompute=True)
        x3 = F.rand(shape=x1.shape, seed=456).apply(recompute=True)
        y = F.mul2(x1, x2).apply(recompute=True)
        y = F.mul2(y, x3).apply(recompute=True)
        y = F.identity(y)
        return y
Example #7
    def build_static_graph(self):
        real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                      self.img_size))
        noises = [
            F.randn(shape=(self.batch_size, self.config['latent_dim']))
            for _ in range(2)
        ]

        if self.config['regularize_gen']:
            fake_img, dlatents = self.generator(self.batch_size,
                                                noises,
                                                return_latent=True)
        else:
            fake_img = self.generator(self.batch_size, noises)
        fake_img_test = self.generator_ema(self.batch_size, noises)

        gen_loss = gen_nonsaturating_loss(self.discriminator(fake_img))

        fake_disc_out = self.discriminator(fake_img)
        real_disc_out = self.discriminator(real_img)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

        var_name_list = [
            'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
            'fake_disc_out', 'real_disc_out', 'fake_img_test'
        ]
        var_list = [
            real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
            real_disc_out, fake_img_test
        ]

        if self.config['regularize_gen']:
            dlatents.need_grad = True
            mean_path_length = nn.Variable()
            pl_reg, path_mean, _ = gen_path_regularize(
                fake_img=fake_img,
                latents=dlatents,
                mean_path_length=mean_path_length)
            path_mean_update = F.assign(mean_path_length, path_mean)
            path_mean_update.name = 'path_mean_update'
            pl_reg += 0 * path_mean_update
            gen_loss_reg = gen_loss + pl_reg
            var_name_list.append('gen_loss_reg')
            var_list.append(gen_loss_reg)

        if self.config['regularize_disc']:
            real_img.need_grad = True
            real_disc_out = self.discriminator(real_img)
            disc_loss_reg = disc_loss + self.config[
                'r1_coeff'] * 0.5 * disc_r1_loss(
                    real_disc_out, real_img) * self.config['disc_reg_step']
            real_img.need_grad = False
            var_name_list.append('disc_loss_reg')
            var_list.append(disc_loss_reg)

        Parameters = namedtuple('Parameters', var_name_list)
        self.parameters = Parameters(*var_list)
Example #8
    def generate_z_anchor(self):
        z_anchor_list = []
        for _ in range(2):
            z_anchor_var = F.gather(
                self.init_z_var, combination(
                    self.n_train, self.batch_size)) + F.randn(
                        sigma=self.anch_std,
                        shape=(self.batch_size, self.latent_dim))
            z_anchor_list.append(z_anchor_var)
        return z_anchor_list
Example #9
    def sample(self, mu, logvar):
        r"""Samples from a Gaussian distribution.

        Args:
            mu (nn.Variable): Mean of the distribution of shape (B, D, 1).
            logvar (nn.Variable): Log variance of the distribution of
                shape (B, D, 1).

        Returns:
            nn.Variable: A sample.
        """
        if self.training:
            eps = F.randn(shape=mu.shape)
            return mu + F.exp(0.5 * logvar) * eps
        return mu
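
The reparameterization trick in isolation, as a sketch with assumed shapes:

import numpy as np
import nnabla as nn
import nnabla.functions as F

nn.set_auto_forward(True)
mu = nn.Variable.from_numpy_array(np.zeros((4, 16, 1), dtype=np.float32))
logvar = nn.Variable.from_numpy_array(np.zeros((4, 16, 1), dtype=np.float32))
eps = F.randn(shape=mu.shape)          # eps ~ N(0, I)
z = mu + F.exp(0.5 * logvar) * eps     # z ~ N(mu, exp(logvar)); grads flow to mu and logvar
print(z.d.mean(), z.d.std())           # roughly 0 and 1 for these inputs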
Example #10
def gen_path_regularize(fake_img,
                        latents,
                        mean_path_length,
                        decay=0.01,
                        pl_weight=2.0):

    noise = F.randn(shape=fake_img.shape) / \
                    np.sqrt(fake_img.shape[2]*fake_img.shape[3])

    gradient = nn.grad([F.sum(fake_img * noise)], [latents])[0]
    path_lengths = F.mean(F.sum(F.pow_scalar(gradient, 2), axis=1), axis=0)
    path_lengths = F.pow_scalar(path_lengths, 0.5)

    path_mean = mean_path_length + decay * \
        (F.mean(path_lengths) - mean_path_length)

    path_penalty = F.mean(
        F.pow_scalar(path_lengths - F.reshape(path_mean, (1, ), inplace=False),
                     2))
    return path_penalty * pl_weight, path_mean, path_lengths
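
A toy check of this path-length regularizer against a linear stand-in generator (the shapes and the affine "generator" are assumptions, not part of the original):

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

latents = nn.Variable((2, 8), need_grad=True)
# Stand-in generator: affine map reshaped to a (B, 3, 4, 4) image.
fake_img = F.reshape(PF.affine(latents, 3 * 4 * 4), (2, 3, 4, 4))
mean_path_length = nn.Variable()  # scalar running mean, starts at 0

penalty, path_mean, lengths = gen_path_regularize(
    fake_img, latents, mean_path_length)

latents.d = np.random.randn(2, 8)
mean_path_length.d = 0.0
penalty.forward()
print(penalty.d, path_mean.d)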
Example #11
    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step), which samples from q(x_t | x_0).
        xt = sqrt(cumprod(alpha_0, ..., alpha_t)) * x_0 + sqrt(1 - cumprod(alpha_0, ..., alpha_t)) * epsilon

        Args:
            x_start (nn.Variable): The (B, C, ...) tensor of x_0.
            t (nn.Variable): A 1-D tensor of timesteps.
            noise (nn.Variable, optional): The (B, C, ...) tensor of epsilon.
                If None, it is sampled with F.randn(shape=x_start.shape).

        Return:
            x_t (nn.Variable): 
                The (B, C, ...) tensor of x_t.
                Each sample x_t[i] corresponds to the noisy image at timestep t[i] constructed from x_start[i].
        """
        if noise is None:
            noise = F.randn(shape=x_start.shape)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod,
                          t, x_start.shape) * noise
        )
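
The two buffers indexed by _extract are typically precomputed from the beta schedule; a sketch of that precomputation (the linear schedule and T=1000 are assumptions):

import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)        # assumed linear beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # cumprod(alpha_0, ..., alpha_t)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)                  # coefficient of x_0
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)  # coefficient of the noise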
Example #12
    def define_network(self):

        if self.use_inst:
            obj_onehot, bm = encode_inputs(self.ist_mask,
                                           self.obj_mask,
                                           n_ids=self.conf.n_class)

            mask = F.concatenate(obj_onehot, bm, axis=1)
        else:
            om = self.obj_mask
            if len(om.shape) == 3:
                om = F.reshape(om, om.shape + (1, ))
            obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
            mask = F.transpose(obj_onehot, (0, 3, 1, 2))

        generator = SpadeGenerator(self.conf.g_ndf,
                                   image_shape=self.conf.image_shape)
        z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
        fake = generator(z, mask)

        # Pixel intensities of fake are [-1, 1]. Rescale it to [0, 1]
        fake = (fake + 1) / 2

        return fake
Example #13
def _smoothing_target(policy_tp1, sigma, c):
    noise_shape = policy_tp1.shape
    smoothing_noise = F.randn(sigma=sigma, shape=noise_shape)
    clipped_noise = clip_by_value(smoothing_noise, -c, c)
    return clip_by_value(policy_tp1 + clipped_noise, -1.0, 1.0)
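
This matches TD3-style target policy smoothing: clipped Gaussian noise is added to the target action, and the result is clipped back to the action range. A standalone sketch (clip_scalar below is a hypothetical stand-in for the clip_by_value helper used above):

import numpy as np
import nnabla as nn
import nnabla.functions as F

nn.set_auto_forward(True)

def clip_scalar(x, low, high):
    # element-wise clip against scalar bounds
    return F.minimum_scalar(F.maximum_scalar(x, low), high)

policy_tp1 = nn.Variable.from_numpy_array(
    np.random.uniform(-1, 1, (4, 2)).astype(np.float32))
noise = clip_scalar(F.randn(sigma=0.2, shape=policy_tp1.shape), -0.5, 0.5)
smoothed = clip_scalar(policy_tp1 + noise, -1.0, 1.0)
print(smoothed.d)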
Example #14
def volumetric_rendering(radiance_field,
                         ray_origins,
                         depth_values,
                         return_weights=False,
                         white_bkgd=False,
                         raw_noise_std=0.0,
                         apply_act=False):
    """Integration of volumetric rendering

    Args:
        radiance_field (nn.Variable or nn.NdArray): Shape is (height, width, num_samples, 4). 
        radiance_field[:,:,:,:3] corresponds to the rgb values at each sampled point, while radiance_field[:,:,:,-1] refers to the volume density.
        ray_origins (nn.Variable or nn.NdArray): Shape is (height, width, 3)
        depth_values (nn.Variable or nn.NdArray): Shape is (num_samples, 1) or (height, width, num_samples) 
        return_weights (bool, optional): Set to true if the coefficients of the volumetric integration sum are to be returned. Defaults to False.

    Returns:
        rgb_map (nn.Variable or nn.NdArray): Shape is (height, width, 3)
        depth_map (nn.Variable or nn.NdArray): Shape is (height, width)
        acc_map (nn.Variable or nn.NdArray): Shape is (height, width)
    """
    if apply_act:
        sigma = F.relu(radiance_field[..., 3])
        rgb = F.sigmoid(radiance_field[..., :3])
    else:
        sigma = radiance_field[..., 3]
        rgb = radiance_field[..., :3]

    if raw_noise_std > 0.0:
        noise = F.randn(shape=sigma.shape)
        sigma += (noise * raw_noise_std)

    if depth_values.ndim == 2:
        distances = depth_values[:, 1:] - depth_values[:, :-1]
        distances = F.concatenate(distances,
                                  F.constant(1e2,
                                             shape=depth_values.shape[:-1] +
                                             (1, )),
                                  axis=-1)
        alpha = 1. - F.exp(-sigma * distances)
        weights = alpha * F.cumprod(1 - alpha + 1e-10, axis=-1, exclusive=True)
        rgb_map = F.sum(weights[..., None] * rgb, axis=-2)
        depth_map = F.sum(weights * depth_values, axis=-1)
        acc_map = F.sum(weights, axis=-1)
    else:
        distances = depth_values[:, :, 1:] - depth_values[:, :, :-1]
        distances = F.concatenate(distances,
                                  F.constant(1e10,
                                             shape=depth_values.shape[:-1] +
                                             (1, )),
                                  axis=-1)
        alpha = 1. - F.exp(-sigma * distances)
        weights = alpha * F.cumprod(1 - alpha + 1e-10, axis=-1, exclusive=True)
        rgb_map = F.sum(weights[..., None] * rgb, axis=rgb.ndim - 2)
        depth_map = F.sum(weights * depth_values, axis=-1)
        acc_map = F.sum(weights, axis=-1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    if return_weights:
        disp_map = 1.0 / \
            F.maximum2(F.constant(1e-10, depth_map.shape), depth_map / acc_map)
        return rgb_map, depth_map, acc_map, disp_map, weights

    return rgb_map, depth_map, acc_map
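
A quick way to exercise the function above on random inputs (the tiny resolution and depth range are assumptions):

import numpy as np
import nnabla as nn
import nnabla.functions as F

nn.set_auto_forward(True)
H, W, S = 8, 8, 16
radiance_field = nn.Variable.from_numpy_array(
    np.random.randn(H, W, S, 4).astype(np.float32))
ray_origins = nn.Variable.from_numpy_array(
    np.zeros((H, W, 3), dtype=np.float32))
depth_values = nn.Variable.from_numpy_array(
    np.tile(np.linspace(2.0, 6.0, S, dtype=np.float32), (H, W, 1)))
rgb_map, depth_map, acc_map = volumetric_rendering(
    radiance_field, ray_origins, depth_values, apply_act=True)
print(rgb_map.shape, depth_map.shape, acc_map.shape)  # (8, 8, 3) (8, 8) (8, 8)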
Example #15
    def generate_z_normal(self):
        z_normal_list = [
            F.randn(shape=(self.batch_size, self.latent_dim)) for _ in range(2)
        ]
        return z_normal_list
Example #16
def generate_data(args):

    if not os.path.isfile(os.path.join(args.weights_path, 'gen_params.h5')):
        os.makedirs(args.weights_path, exist_ok=True)
        print(
            "Downloading the pretrained tf-converted weights. Please wait...")
        url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
        from nnabla.utils.data_source_loader import download
        download(url, os.path.join(args.weights_path, 'gen_params.h5'), False)

    nn.load_parameters(os.path.join(args.weights_path, 'gen_params.h5'))
    print('Loaded pretrained weights from tensorflow!')

    os.makedirs(args.save_image_path, exist_ok=True)

    batches = [
        args.batch_size for _ in range(args.num_images // args.batch_size)
    ]
    if args.num_images % args.batch_size != 0:
        batches.append(args.num_images -
                       (args.num_images // args.batch_size) * args.batch_size)

    for idx, batch_size in enumerate(batches):
        z = [
            F.randn(shape=(batch_size, 512)).data,
            F.randn(shape=(batch_size, 512)).data
        ]

        for i in range(len(z)):
            z[i] = F.div2(
                z[i],
                F.pow_scalar(F.add_scalar(
                    F.mean(z[i]**2., axis=1, keepdims=True), 1e-8),
                             0.5,
                             inplace=True))

        # get latent code
        w = [mapping_network(z[0], outmaps=512, num_layers=8)]
        w += [mapping_network(z[1], outmaps=512, num_layers=8)]

        # truncation trick
        dlatent_avg = nn.parameter.get_parameter_or_create(name="dlatent_avg",
                                                           shape=(1, 512))
        w = [lerp(dlatent_avg, _, 0.7) for _ in w]

        # Load direction
        if not args.face_morph:
            attr_delta = nn.NdArray.from_numpy_array(
                np.load(args.attr_delta_path))
            attr_delta = F.reshape(attr_delta[0], (1, -1))
            w_plus = [w[0] + args.coeff * attr_delta, w[1]]
            w_minus = [w[0] - args.coeff * attr_delta, w[1]]
        else:
            w_plus = [w[0], w[0]]  # content
            w_minus = [w[1], w[1]]  # style

        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
        constant_bc = F.broadcast(constant_bc,
                                  (batch_size, ) + constant_bc.shape[1:])

        gen_plus = synthesis(w_plus, constant_bc, noise_seed=100, mix_after=8)
        gen_minus = synthesis(w_minus,
                              constant_bc,
                              noise_seed=100,
                              mix_after=8)
        gen = synthesis(w, constant_bc, noise_seed=100, mix_after=8)

        image_plus = convert_images_to_uint8(gen_plus, drange=[-1, 1])
        image_minus = convert_images_to_uint8(gen_minus, drange=[-1, 1])
        image = convert_images_to_uint8(gen, drange=[-1, 1])

        for j in range(batch_size):
            filepath = os.path.join(args.save_image_path,
                                    f'image_{idx*batch_size+j}')
            imsave(f'{filepath}_o.png', image_plus[j], channel_first=True)
            imsave(f'{filepath}_y.png', image_minus[j], channel_first=True)
            imsave(f'{filepath}.png', image[j], channel_first=True)
            print(f"Genetated. Saved {filepath}")
Example #17
def vae(x, shape_z, test=False):
    """
    Function for calculating the ELBO (evidence lower bound) loss.
    This sample is a Bernoulli generator version.

    Args:
        x(`~nnabla.Variable`): N-D array
        shape_z(tuple of int): size of z
        test (bool): If True, use the mean deterministically (inference);
            if False, sample z with the reparameterization trick (training).

    Returns:
        ~nnabla.Variable: Elbo loss

    """

    #############################################
    # Encoder of 2 fully connected layers       #
    #############################################

    # Normalize input
    xa = x / 256.
    batch_size = x.shape[0]

    # 2 fully connected layers, and Elu replaced from original Softplus.
    h = F.elu(PF.affine(xa, (500, ), name='fc1'))
    h = F.elu(PF.affine(h, (500, ), name='fc2'))

    # The outputs are the parameters of Gauss probability density.
    mu = PF.affine(h, shape_z, name='fc_mu')
    logvar = PF.affine(h, shape_z, name='fc_logvar')
    sigma = F.exp(0.5 * logvar)

    # The prior variable and the reparameterization trick
    if not test:
        # training with reparameterization trick
        epsilon = F.randn(mu=0, sigma=1, shape=(batch_size, ) + shape_z)
        z = mu + sigma * epsilon
    else:
        # test without randomness
        z = mu

    #############################################
    # Decoder of 2 fully connected layers       #
    #############################################

    # 2 fully connected layers, and Elu replaced from original Softplus.
    h = F.elu(PF.affine(z, (500, ), name='fc3'))
    h = F.elu(PF.affine(h, (500, ), name='fc4'))

    # The outputs are the parameters of Bernoulli probabilities for each pixel.
    prob = PF.affine(h, (1, 28, 28), name='fc5')

    #############################################
    # Elbo components and loss objective        #
    #############################################

    # Binarized input
    xb = F.greater_equal_scalar(xa, 0.5)

    # E_q(z|x)[log(q(z|x))]
    # without some constant terms that will be canceled after summation of loss
    logqz = 0.5 * F.sum(1.0 + logvar, axis=1)

    # E_q(z|x)[log(p(z))]
    # without some constant terms that will be canceled after summation of loss
    logpz = 0.5 * F.sum(mu * mu + sigma * sigma, axis=1)

    # E_q(z|x)[log(p(x|z))]
    logpx = F.sum(F.sigmoid_cross_entropy(prob, xb), axis=(1, 2, 3))

    # VAE loss, the negative evidence lower bound
    loss = F.mean(logpx + logpz - logqz)

    return loss
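
A sketch of driving the ELBO above with a dummy MNIST-shaped batch (the batch size and latent size are assumptions):

import numpy as np
import nnabla as nn

x = nn.Variable((16, 1, 28, 28))
loss = vae(x, shape_z=(50,), test=False)

x.d = np.random.randint(0, 256, size=x.shape)  # fake 8-bit pixel values
loss.forward()
print(loss.d)  # negative evidence lower bound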
Example #18
    def train_loss(self, model, x_start, t, noise=None):
        """
        Calculate training loss for given data and model.

        Args:
            model (callable): 
                A trainable model to predict noise in data conditioned by timestep.
                This function should perform like pred_noise = model(x_noisy, t).
                If self.model_var_type requires a sigma prediction, the model has to output it as well.
            x_start (nn.Variable): The (B, C, ...) tensor of x_0.
            t (nn.Variable): A 1-D tensor of timesteps.
            noise (nn.Variable or None): A noise Variable. If None, F.randn(shape=x_start.shape) will be used.

        Return:
            loss (dict of {string: nn.Variable}): 
                Return dict that has losses to train the `model`.
                You can access each loss by a name that will be:
                    - `vlb`: Variational Lower Bound for learning sigma. 
                             This is included only if self.model_var_type requires learning sigma.
                    - `mse`: MSE between actual and predicted noise.
                Each entry is the (B, ) tensor of batched loss computed from given inputs.
                Note that this function doesn't reduce batch dim
                in order to make it easy to trace the loss value at each timestep.
                Therefore, you should average the returned Variables over the batch dim to train the model.
        """
        B, C, H, W = x_start.shape
        assert t.shape == (B, )

        if noise is None:
            noise = F.randn(shape=x_start.shape)
        assert noise.shape == x_start.shape

        # Calculate x_t from x_start, t, and noise.
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        assert x_noisy.shape == x_start.shape

        # Predict noise.
        # According to the original DDPM, this is superior to reconstructing x_0.
        # If model_var_type requires learning sigma, the model must output pred_sigma as well.
        pred = model(x_noisy, t)

        # Calculate losses
        ret = AttrDict()

        if is_learn_sigma(self.model_var_type):
            # split pred into 2 variables along channel axis.
            pred_noise, pred_sigma = chunk(pred, num_chunk=2, axis=1)
            assert pred_sigma.shape == x_start.shape, \
                f"Shape mismutch between pred_sigma {pred_sigma.shape} and x_start {x_start.shape}"

            # Variational lower bound for sigma
            # Use dummy function as model, since we already got prediction from model.
            var = F.concatenate(pred_noise.get_unlinked_variable(
                need_grad=False), pred_sigma, axis=1)
            ret.vlb = self._vlb_in_bits_per_dims(model=lambda x_t, t: var,
                                                 x_start=x_start,
                                                 x_t=x_noisy,
                                                 t=t)
        else:
            pred_noise = pred

        assert pred_noise.shape == x_start.shape, \
            f"Shape mismutch between pred_noise {pred_sigma.shape} and x_start {x_start.shape}"

        ret.mse = mean_along_except_batch(F.squared_error(noise, pred_noise))

        # shape check for all losses
        for name, loss in ret.items():
            assert loss.shape == (B, ), \
                 f"A Variabla for loss `{name}` has a wrong shape ({loss.shape} != {(B, )})"

        return ret
Example #19
    def build_static_graph(self):
        real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                      self.img_size))
        noises = [
            F.randn(shape=(self.batch_size, self.config['latent_dim']))
            for _ in range(2)
        ]
        if self.few_shot_config['common']['type'] == 'cdc':
            NT_class = NoiseTop(n_train=self.train_loader.size,
                                latent_dim=self.config['latent_dim'],
                                batch_size=self.batch_size)
            noises = NT_class()
            self.PD_switch_var = NT_class.PD_switch_var

        if self.config['regularize_gen']:
            fake_img, dlatents = self.generator(self.batch_size,
                                                noises,
                                                return_latent=True)
        else:
            fake_img = self.generator(self.batch_size, noises)
        fake_img_test = self.generator_ema(self.batch_size, noises)

        if self.few_shot_config['common']['type'] != 'cdc':
            fake_disc_out = self.discriminator(fake_img)
            real_disc_out = self.discriminator(real_img)
            disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

        gen_loss = 0
        if self.few_shot_config['common']['type'] == 'cdc':
            fake_img_s = self.generator_s(self.batch_size, noises)
            cdc_loss = CrossDomainCorrespondence(
                fake_img,
                fake_img_s,
                _choice_num=self.few_shot_config['cdc']['feature_num'],
                _layer_fix_switch=self.few_shot_config['cdc']['layer_fix'])
            gen_loss += self.few_shot_config['cdc']['lambda'] * cdc_loss
            # --- PatchDiscriminator ---
            fake_disc_out, fake_feature_var = self.discriminator(
                fake_img, patch_switch=True, index=0)
            real_disc_out, real_feature_var = self.discriminator(
                real_img, patch_switch=True, index=0)
            disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)
            disc_loss_patch = disc_logistic_loss(fake_feature_var,
                                                 real_feature_var)
            disc_loss += self.PD_switch_var * disc_loss_patch

        gen_loss += gen_nonsaturating_loss(fake_disc_out)

        var_name_list = [
            'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
            'fake_disc_out', 'real_disc_out', 'fake_img_test'
        ]
        var_list = [
            real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
            real_disc_out, fake_img_test
        ]

        if self.config['regularize_gen']:
            dlatents.need_grad = True
            mean_path_length = nn.Variable()
            pl_reg, path_mean, _ = gen_path_regularize(
                fake_img=fake_img,
                latents=dlatents,
                mean_path_length=mean_path_length)
            path_mean_update = F.assign(mean_path_length, path_mean)
            path_mean_update.name = 'path_mean_update'
            pl_reg += 0 * path_mean_update
            gen_loss_reg = gen_loss + pl_reg
            var_name_list.append('gen_loss_reg')
            var_list.append(gen_loss_reg)

        if self.config['regularize_disc']:
            real_img.need_grad = True
            real_disc_out = self.discriminator(real_img)
            disc_loss_reg = disc_loss + self.config[
                'r1_coeff'] * 0.5 * disc_r1_loss(
                    real_disc_out, real_img) * self.config['disc_reg_step']
            real_img.need_grad = False
            var_name_list.append('disc_loss_reg')
            var_list.append(disc_loss_reg)

        Parameters = namedtuple('Parameters', var_name_list)
        self.parameters = Parameters(*var_list)
Example #20
def generate_attribute_direction(args, attribute_prediction_model):

    if not os.path.isfile(os.path.join(args.weights_path, 'gen_params.h5')):
        os.makedirs(args.weights_path, exist_ok=True)
        print(
            "Downloading the pretrained tf-converted weights. Please wait...")
        url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
        from nnabla.utils.data_source_loader import download
        download(url, os.path.join(args.weights_path, 'gen_params.h5'), False)

    nn.load_parameters(os.path.join(args.weights_path, 'gen_params.h5'))
    print('Loaded pretrained weights from tensorflow!')

    nn.load_parameters(args.classifier_weight_path)
    print(f'Loaded {args.classifier_weight_path}')

    batches = [
        args.batch_size for _ in range(args.num_images // args.batch_size)
    ]
    if args.num_images % args.batch_size != 0:
        batches.append(args.num_images -
                       (args.num_images // args.batch_size) * args.batch_size)

    w_plus, w_minus = 0.0, 0.0
    w_plus_count, w_minus_count = 0.0, 0.0
    pbar = trange(len(batches))
    for i in pbar:
        batch_size = batches[i]
        z = [F.randn(shape=(batch_size, 512)).data]

        z = [z[0], z[0]]

        for j in range(len(z)):
            z[j] = F.div2(
                z[j],
                F.pow_scalar(F.add_scalar(
                    F.mean(z[j]**2., axis=1, keepdims=True), 1e-8),
                             0.5,
                             inplace=True))

        # get latent code
        w = [mapping_network(z[0], outmaps=512, num_layers=8)]
        w += [mapping_network(z[1], outmaps=512, num_layers=8)]

        # truncation trick
        dlatent_avg = nn.parameter.get_parameter_or_create(name="dlatent_avg",
                                                           shape=(1, 512))
        w = [lerp(dlatent_avg, _, 0.7) for _ in w]

        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
        constant_bc = F.broadcast(constant_bc,
                                  (batch_size, ) + constant_bc.shape[1:])

        gen = synthesis(w, constant_bc, noise_seed=100, mix_after=7)

        classifier_score = F.softmax(attribute_prediction_model(gen, True))
        confidence, class_pred = F.max(classifier_score,
                                       axis=1,
                                       with_index=True,
                                       keepdims=True)

        w_plus += np.sum(w[0].data * (class_pred.data == 0) *
                         (confidence.data > 0.65),
                         axis=0,
                         keepdims=True)
        w_minus += np.sum(w[0].data * (class_pred.data == 1) *
                          (confidence.data > 0.65),
                          axis=0,
                          keepdims=True)

        w_plus_count += np.sum(
            (class_pred.data == 0) * (confidence.data > 0.65))
        w_minus_count += np.sum(
            (class_pred.data == 1) * (confidence.data > 0.65))

        pbar.set_description(f'{w_plus_count} {w_minus_count}')

    # save attribute direction
    attribute_variation_direction = (w_plus / w_plus_count) - (w_minus /
                                                               w_minus_count)
    print(w_plus_count, w_minus_count)
    np.save(f'{args.classifier_weight_path.split("/")[0]}/direction.npy',
            attribute_variation_direction)
Example #21
def train():
    rng = np.random.RandomState(803)

    conf = get_config()

    comm = init_nnabla(conf)

    # create data iterator
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           save_file=comm.rank == 0)
        n_class = conf.cityscapes.n_label_ids
        use_inst = True

        data_iter = create_cityscapes_iterator(conf.batch_size,
                                               data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng,
                                               flip=conf.use_flip)

    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k, save_file=comm.rank == 0)
        n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False

        load_shape = tuple(
            x + 30
            for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size,
                                           data_list,
                                           comm=comm,
                                           load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng,
                                           flip=conf.use_flip)

    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape)
    obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)

    if use_inst:
        ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)
        obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    # generator
    generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape)
    z = F.randn(shape=(conf.batch_size, conf.z_dim))
    fake = generator(z, mask)

    # unlinking
    ul_mask, ul_fake = get_unlinked_all(mask, fake)

    # discriminator
    discriminator = PatchGAN(n_scales=conf.d_n_scales)
    d_input_real = F.concatenate(real, ul_mask, axis=1)
    d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1)
    d_real_out, d_real_feats = discriminator(d_input_real)
    d_fake_out, d_fake_feats = discriminator(d_input_fake)

    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(
        d_real_out,
        d_real_feats,
        d_fake_out,
        d_fake_feats,
        use_fm=conf.use_fm,
        fm_lambda=conf.lambda_fm,
        gan_loss_type=conf.gan_loss_type)

    def _rescale(x):
        return rescale_values(x,
                              input_min=-1,
                              input_max=1,
                              output_min=0,
                              output_max=255)

    g_vgg = vgg16_perceptual_loss(_rescale(ul_fake),
                                  _rescale(real)) * conf.lambda_vgg

    set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg)

    # loss
    g_loss = g_gan + g_feat + g_vgg
    d_loss = (d_real + d_fake) / 2

    # load params
    if conf.load_params is not None:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    # Setup Solvers
    g_solver = S.Adam(beta1=0.)
    g_solver.set_parameters(get_params_startswith("spade_generator"))

    d_solver = S.Adam(beta1=0.)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    g_lrs = LinearDecayScheduler(start_lr=conf.g_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)
    d_lrs = LinearDecayScheduler(start_lr=conf.d_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)

    ipe = get_iteration_per_epoch(data_iter._size,
                                  conf.batch_size,
                                  round="ceil")

    if not conf.show_interval:
        conf.show_interval = ipe
    if not conf.save_interval:
        conf.save_interval = ipe
    if not conf.niter:
        conf.niter = 200 * ipe

    # Setup Reporter
    losses = {
        "g_gan": g_gan,
        "g_feat": g_feat,
        "g_vgg": g_vgg,
        "d_real": d_real,
        "d_fake": d_fake
    }
    reporter = Reporter(comm,
                        losses,
                        conf.save_path,
                        nimage_per_epoch=min(conf.batch_size, 5),
                        show_interval=10)
    progress_iterator = trange(conf.niter, disable=comm.rank > 0)
    reporter.start(progress_iterator)

    colorizer = Colorize(n_class)

    # output all config and dump to file
    if comm.rank == 0:
        conf.dump_to_stdout()
        write_yaml(os.path.join(conf.save_path, "config.yaml"), conf)

    epoch = 0
    for itr in progress_iterator:
        if itr % ipe == 0:
            g_lr = g_lrs(epoch)
            d_lr = d_lrs(epoch)
            g_solver.set_learning_rate(g_lr)
            d_solver.set_learning_rate(d_lr)
            if comm.rank == 0:
                print(
                    "\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format(
                        epoch, g_lr, d_lr))

            epoch += 1

        if conf.dataset == "cityscapes":
            im, ist, obj = data_iter.next()
            ist_mask.d = ist
        elif conf.dataset == "ade20k":
            im, obj = data_iter.next()
        else:
            raise NotImplementedError()

        real.d = im
        obj_mask.d = obj

        # text embedding and create fake
        fake.forward()

        # update discriminator
        d_solver.zero_grad()
        d_loss.forward()
        d_loss.backward(clear_buffer=True)
        comm.all_reduced_solver_update(d_solver)

        # update generator
        ul_fake.grad.zero()
        g_solver.zero_grad()
        g_loss.forward()
        g_loss.backward(clear_buffer=True)

        # backward generator
        fake.backward(grad=None, clear_buffer=True)
        comm.all_reduced_solver_update(g_solver)

        # report iteration progress
        reporter()

        # report epoch progress
        show_epoch = itr // conf.show_interval
        if (itr % conf.show_interval) == 0:
            show_images = {
                "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)),
                "ObjectMask": colorizer(obj).astype(np.uint8),
                "GeneratedImage": fake.data.get_data("r").transpose(
                    (0, 2, 3, 1))
            }

            reporter.step(show_epoch, show_images)

        if (itr % conf.save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(conf.save_path,
                             'param_{:03d}.h5'.format(show_epoch)))

    if comm.rank == 0:
        nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
Example #22
    def project(self, args):
        nn.set_auto_forward(True)

        # Input Image Variable
        image = Image.open(args.img_path).convert("RGB").resize(
            (256, 256), resample=Image.BILINEAR)
        image = np.array(image) / 255.0
        image = np.transpose(image.astype(np.float32), (2, 0, 1))
        image = np.expand_dims(image, 0)
        image = (image - 0.5) / (0.5)
        image = nn.Variable.from_numpy_array(image)

        # Get Latent Space Mean and Std. Dev.
        # Get Noise for B network
        z = F.randn(shape=(self.n_latent, self.latent_dim)).data
        w = mapping_network(z)
        latent_mean = F.mean(w, axis=0, keepdims=True)
        latent_std = F.pow_scalar(F.mean(F.pow_scalar(w - latent_mean, 2)),
                                  0.5)

        # Get Noise
        noises = [F.randn(shape=(1, 1, 4, 4)).data]

        for res in self.generator.resolutions[1:]:
            for _ in range(2):
                shape = (1, 1, res, res)
                noises.append(F.randn(shape=shape).data)

        # Prepare parameters to be optimized
        latent_in = nn.Variable.from_numpy_array(
            latent_mean.data).apply(need_grad=True)
        noises = [
            nn.Variable.from_numpy_array(n.data).apply(need_grad=True)
            for n in noises
        ]

        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const", shape=(1, 512, 4, 4))
        constant_bc = F.broadcast(constant_bc, (1, ) + constant_bc.shape[1:])

        pbar = tqdm(range(self.num_iters))
        for i in pbar:

            t = i / self.num_iters
            self.set_lr(t)

            noise_strength = latent_std * 0.05 * max(0, 1 - t / 0.75)**2
            latent_n = self.latent_noise(latent_in, noise_strength)

            gen_out = self.generator.synthesis([latent_n, latent_n],
                                               constant_bc,
                                               noises_in=noises)
            N, C, H, W = gen_out.shape
            factor = H // 256
            gen_out = F.reshape(
                gen_out, (N, C, H // factor, factor, W // factor, factor))
            gen_out = F.mean(gen_out, axis=(3, 5))

            p_loss = F.sum(self.lpips_distance(image, gen_out))
            n_loss = self.regularize_noise(noises)
            mse_loss = F.mean((gen_out - image)**2)
            loss = p_loss + self.n_c * n_loss + self.mse_c * mse_loss

            param_dict = {'latent': latent_in}
            for j in range(len(noises)):
                param_dict[f'noise_{j}'] = noises[j]
            self.solver.zero_grad()
            self.solver.set_parameters(param_dict,
                                       reset=False,
                                       retain_state=True)

            loss.backward()
            self.solver.update()

            noises = self.normalize_noises(noises)

            pbar.set_description(f'Loss: {loss.d} P Loss: {p_loss.d}')

        save_generations(image, 'original.png')

        gen_out = self.generator.synthesis([latent_n, latent_n],
                                           constant_bc,
                                           noises_in=noises)
        N, C, H, W = gen_out.shape
        factor = H // 256
        gen_out = F.reshape(gen_out,
                            (N, C, H // factor, factor, W // factor, factor),
                            inplace=True)
        gen_out = F.mean(gen_out, axis=(3, 5))
        save_generations(gen_out, 'projected.png')

        nn.save_parameters('projection_params.h5', param_dict)
Example #23
    def _transition(self, epochs_per_resolution):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        phase = "{}to{}".format(
            self.gen.resolution_list[-2], self.gen.resolution_list[-1])
        logger.info("phase : {}".format(phase))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        total_itr = (self.di.size // batch_size + 1) * epochs_per_resolution
        global_itr = 1.
        alpha = global_itr / total_itr

        for epoch in range(epochs_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=True)
                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis.transition(x_r, alpha)
                p_fake = self.dis.transition(y, alpha)

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)

                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=False)
                p_fake = self.dis.transition(y, alpha)

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2))
                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(
                        nn.get_parameters(), reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                itr += 1
                global_itr += 1.
                alpha = global_itr / total_itr

            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha)
                img_name = "phase_{}_epoch_{}".format(phase, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))
Example #24
    def _train(self, epochs_per_resolution, each_save=False):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        logger.info("phase : {}".format(resolution))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        img_name = "original_phase_{}".format(resolution)
        img, _ = self.di.next()
        self.monitor_image_tile.add(img_name, img)

        for epoch in range(epochs_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)

                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis(x_r)
                p_fake = self.dis(y)
                p_real.persistent, p_fake.persistent = True, True

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)
                loss_dis.persistent = True

                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=False)
                p_fake = self.dis(y)
                p_fake.persistent = True

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2.))
                loss_gen.persistent = True

                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(nn.get_parameters(),
                                                   reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                # Monitor
                self.monitor_p_real.add(
                    self.global_itr, p_real.d.copy().mean())
                self.monitor_p_fake.add(
                    self.global_itr, p_fake.d.copy().mean())
                self.monitor_loss_dis.add(self.global_itr, loss_dis.d.copy())
                self.monitor_loss_gen.add(self.global_itr, loss_gen.d.copy())

                itr += 1
                self.global_itr += 1

            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)
                img_name = "phase_{}_epoch_{}".format(resolution, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))

            if each_save:
                self.gen.save_parameters(self.monitor_path, "Gen_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))
                self.dis.save_parameters(self.monitor_path, "Dis_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))
Example #25
    def infer(self, mels, sigma=0.9):
        r"""Returns the generated audio.

        Args:
            mels (nn.Variable): Inputs containing mel-spectrograms of shape (B, n_mels, Ty).
            sigma (float, optional): Sigma used to infer audio. Defaults to 0.9.

        Returns:
            nn.Variable: A synthetic audio.
        """

        hp = self.hparams
        with nn.parameter_scope('', self.parameter_scope):

            #  Upsample spectrogram to size of audio
            with nn.parameter_scope('upsample'):
                with nn.parameter_scope('deconv'):
                    mels = PF.deconvolution(mels,
                                            hp.n_mels,
                                            kernel=(1024, ),
                                            stride=(256, ))
                # cutout conv artifacts
                mels = mels[..., :-(1024 - 256)]  # kernel - stride

                # transforming to correct shape
                mels = F.reshape(mels,
                                 mels.shape[:2] + (-1, hp.n_samples_per_group))
                mels = F.transpose(mels, (0, 2, 1, 3))
                mels = F.reshape(mels, mels.shape[:2] + (-1, ))
                # (B, n_mels * n_groups, L/n_groups)
                mels = F.transpose(mels, (0, 2, 1))

            wave = F.randn(shape=(mels.shape[0], self.n_remaining_channels,
                                  mels.shape[2])) * sigma

            for k in reversed(range(hp.n_flows)):
                n_half = wave.shape[1] // 2
                audio_0 = wave[:, :n_half, :]
                audio_1 = wave[:, n_half:, :]

                with nn.parameter_scope(f'wn_{k}'):
                    output = getattr(self, f'WN_{k}')(audio_0, mels)
                    s = output[:, n_half:, :]
                    b = output[:, :n_half, :]
                    audio_1 = (audio_1 - b) / F.exp(s)
                    wave = F.concatenate(audio_0, audio_1, axis=1)

                wave = invertible_conv(wave,
                                       reverse=True,
                                       rng=self.rng,
                                       scope=f'inv_{k}')

                if k % hp.n_early_every == 0 and k > 0:
                    z = F.randn(shape=(mels.shape[0], hp.n_early_size,
                                       mels.shape[2]))
                    wave = F.concatenate(sigma * z, wave, axis=1)

            wave = F.transpose(wave, (0, 2, 1))
            wave = F.reshape(wave, (wave.shape[0], -1))

        return wave
Example #26
def generate(args):
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])
    one = nn.Variable.from_numpy_array(np.ones((1, 1, 1, 1)) * 0.5)

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage("Fake Image B to A {} Valid".format(suffix),
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B {} Valid".format(suffix),
                                   monitor,
                                   interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    # Generate all
    # generate (A -> B)
    if args.example_guided:
        x_real_b.d = di_b.next()[0]
    for i in range(di_a.size):
        x_real_a.d = di_a.next()[0]
        images = []
        images.append(x_real_a.d.copy())
        for _ in range(args.num_repeats):
            x_fake_b.forward(clear_buffer=True)
            images.append(x_fake_b.d.copy())
        monitor_image_b.add(i, np.concatenate(images, axis=3))

    # generate (B -> A)
    if args.example_guided:
        x_real_a.d = di_a.next()[0]
    for i in range(di_b.size):
        x_real_b.d = di_b.next()[0]
        images = []
        images.append(x_real_b.d.copy())
        for _ in range(args.num_repeats):
            x_fake_a.forward(clear_buffer=True)
            images.append(x_fake_a.d.copy())
        monitor_image_a.add(i, np.concatenate(images, axis=3))
Example #30
    def synthesis(self, w_mixed, constant_bc, seed=-1, noises_in=None):

        batch_size = w_mixed.shape[0]

        if noises_in is None:
            noise = F.randn(shape=(batch_size, 1, 4, 4), seed=seed)
        else:
            noise = noises_in[0]
        w = F.reshape(F.slice(w_mixed,
                              start=(0, 0, 0),
                              stop=(w_mixed.shape[0], 1, w_mixed.shape[2]),
                              step=(1, 1, 1)),
                      (w_mixed.shape[0], w_mixed.shape[2]),
                      inplace=False)
        h = styled_conv_block(constant_bc,
                              w,
                              noise,
                              res=self.resolutions[0],
                              outmaps=self.feature_map_dim,
                              namescope="Conv")
        torgb = styled_conv_block(h,
                                  w,
                                  noise=None,
                                  res=self.resolutions[0],
                                  outmaps=3,
                                  inmaps=self.feature_map_dim,
                                  kernel_size=1,
                                  pad_size=0,
                                  demodulate=False,
                                  namescope="ToRGB",
                                  act=F.identity)

        # initial feature maps
        outmaps = self.feature_map_dim
        inmaps = self.feature_map_dim

        # feature maps start halving after this block for 512/1024 models
        downsize_index = 4 if self.resolutions[-1] in [512, 1024] else 3

        # resolution blocks from 8 x 8 up to 1024 x 1024
        for i in range(1, len(self.resolutions)):

            # three w vectors per block, indexing Conv0_up, Conv1 and ToRGB
            i1 = (2 + i) * 2 - 5
            i2 = (2 + i) * 2 - 4
            i3 = (2 + i) * 2 - 3
            w_ = F.reshape(F.slice(w_mixed,
                                   start=(0, i1, 0),
                                   stop=(w_mixed.shape[0], i1 + 1,
                                         w_mixed.shape[2]),
                                   step=(1, 1, 1)),
                           w.shape,
                           inplace=False)
            if i > downsize_index:
                outmaps = outmaps // 2
            curr_shape = (batch_size, 1, self.resolutions[i],
                          self.resolutions[i])
            if noises_in is None:
                noise = F.randn(shape=curr_shape, seed=seed)
            else:
                noise = noises_in[2 * i - 1]

            h = styled_conv_block(h,
                                  w_,
                                  noise,
                                  res=self.resolutions[i],
                                  outmaps=outmaps,
                                  inmaps=inmaps,
                                  kernel_size=3,
                                  up=True,
                                  namescope="Conv0_up")

            w_ = F.reshape(F.slice(w_mixed,
                                   start=(0, i2, 0),
                                   stop=(w_mixed.shape[0], i2 + 1,
                                         w_mixed.shape[2]),
                                   step=(1, 1, 1)),
                           w.shape,
                           inplace=False)
            if i > downsize_index:
                inmaps = inmaps // 2
            if noises_in is None:
                noise = F.randn(shape=curr_shape, seed=seed)
            else:
                noise = noises_in[2 * i]
            h = styled_conv_block(h,
                                  w_,
                                  noise,
                                  res=self.resolutions[i],
                                  outmaps=outmaps,
                                  inmaps=inmaps,
                                  kernel_size=3,
                                  pad_size=1,
                                  namescope="Conv1")

            w_ = F.reshape(F.slice(w_mixed,
                                   start=(0, i3, 0),
                                   stop=(w_mixed.shape[0], i3 + 1,
                                         w_mixed.shape[2]),
                                   step=(1, 1, 1)),
                           w.shape,
                           inplace=False)
            curr_torgb = styled_conv_block(h,
                                           w_,
                                           noise=None,
                                           res=self.resolutions[i],
                                           outmaps=3,
                                           inmaps=inmaps,
                                           kernel_size=1,
                                           pad_size=0,
                                           demodulate=False,
                                           namescope="ToRGB",
                                           act=F.identity)

            torgb = F.add2(curr_torgb, upsample_2d(torgb, k=[1, 3, 3, 1]))

        return torgb
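
A minimal driver for synthesis(), assuming gen is an instance of the class above with resolutions [4, 8, ..., 1024] and feature_map_dim 512; the learned constant input is replaced by a random stand-in here purely for illustration:

import numpy as np
import nnabla as nn
import nnabla.functions as F

nn.set_auto_forward(True)

batch_size = 1
num_layers = 2 * len(gen.resolutions)  # indices 0 .. 2*len-1 are sliced above
w = F.randn(shape=(batch_size, 1, 512))
# broadcast one w to every layer; style mixing would splice two such stacks
w_mixed = F.tile(w, (1, num_layers, 1))
# stand-in for the learned constant 4x4 input (a parameter in the real model)
constant_bc = nn.Variable.from_numpy_array(
    np.random.randn(batch_size, gen.feature_map_dim, 4, 4).astype(np.float32))
rgb = gen.synthesis(w_mixed, constant_bc, seed=100)
print(rgb.shape)  # (1, 3, 1024, 1024)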
Example #31
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround: fix the seed so every device starts from identical weights.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domains A and B)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(x_style_rec_a,
                                      z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(x_content_rec_b,
                                        x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(x_style_rec_b,
                                      z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(x_content_rec_a,
                                        x_content_a).apply(persistent=True)

    # adversarial

    def f(x, y):
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor,
                                                 interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor,
                                                 interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train",
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train",
                                   monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        # cut gradients at the fake images so only the discriminators update
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))
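
The learning-rate schedule above is written in local iterations, so dividing lr_decay_at_every by n_devices keeps the decay points fixed in terms of global iterations. A quick check with assumed numbers:

lr0, decay_rate = 1e-4, 0.5
lr_decay_at_every, n_devices, max_iter = 100_000, 4, 500_000
# one decay every (lr_decay_at_every // n_devices) local steps,
# i.e. every lr_decay_at_every global steps -> 5 decays in total
n_decays = (max_iter // n_devices) // (lr_decay_at_every // n_devices)
print(lr0 * decay_rate ** n_decays)  # 3.125e-06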
Example #32
    def sample_loop(self, model, shape, sampler,
                    noise=None,
                    dump_interval=-1,
                    progress=False,
                    without_auto_forward=False):
        """
        Iteratively Sample data from model from t=T to t=0.
        T is specified as the length of betas given to __init__().

        Args:
            model (collable): 
                A callable that takes x_t and t and predict noise (and sigma related parameters).
            shape (list like object): A data shape.
            sampler (callable): A function to sample x_{t-1} given x_{t} and t. Typically, self.p_sample or self.ddim_sample.
            noise (collable): A noise generator. If None, F.randn(shape) will be used.
            interval (int): 
                If > 0, all intermediate results at every `interval` step will be returned as a list.
                e.g. if interval = 10, the predicted results at {10, 20, 30, ...} will be returned.
            progress (bool): If True, tqdm will be used to show the sampling progress.

        Returns:
            - x_0 (nn.Variable): the final sampled result of x_0
            - samples (a list of nn.Variable): the sampled results at every `interval`
            - pred_x_starts (a list of nn.Variable): the predicted x_0 from each x_t at every `interval`: 
        """
        T = self.num_timesteps
        indices = list(range(T))[::-1]

        samples = []
        pred_x_starts = []

        if progress:
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        if without_auto_forward:
            if noise is None:
                noise = np.random.randn(*shape)
            else:
                assert isinstance(noise, np.ndarray)
                assert noise.shape == shape

            x_t = nn.Variable.from_numpy_array(noise)
            t = nn.Variable.from_numpy_array([T - 1 for _ in range(shape[0])])

            # build a static graph once; F.assign writes the sampler output
            # back into x_t and decrements t, so each forward of `update`
            # advances the chain by one step
            y, pred_x_start = sampler(model, x_t, t)
            up_x_t = F.assign(x_t, y)
            up_t = F.assign(t, t - 1)
            update = F.sink(up_x_t, up_t)

            cnt = 0
            for step in indices:
                y.forward(clear_buffer=True)
                update.forward(clear_buffer=True)

                cnt += 1
                if dump_interval > 0 and cnt % dump_interval == 0:
                    samples.append((step, y.d.copy()))
                    pred_x_starts.append((step, pred_x_start.d.copy()))
        else:
            with nn.auto_forward():
                if noise is None:
                    x_t = F.randn(shape=shape)
                else:
                    assert isinstance(noise, np.ndarray)
                    assert noise.shape == shape
                    x_t = nn.Variable.from_numpy_array(noise)
                cnt = 0
                for step in indices:
                    t = F.constant(step, shape=(shape[0], ))
                    x_t, pred_x_start = sampler(
                        model, x_t, t, no_noise=step == 0)
                    cnt += 1
                    if dump_interval > 0 and cnt % dump_interval == 0:
                        samples.append((step, x_t.d.copy()))
                        pred_x_starts.append((step, pred_x_start.d.copy()))

        assert x_t.shape == shape
        return x_t.d.copy(), samples, pred_x_starts
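
A hypothetical call, assuming diffusion is an instance of the class above and model follows the callable signature described in the docstring:

import numpy as np

x0, samples, pred_x_starts = diffusion.sample_loop(
    model,
    shape=(4, 3, 64, 64),           # 4 images of assumed size 3x64x64
    sampler=diffusion.p_sample,     # or diffusion.ddim_sample
    dump_interval=100,              # keep intermediates every 100 steps
    progress=True,
)
assert isinstance(x0, np.ndarray) and x0.shape == (4, 3, 64, 64)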