Пример #1
0
def compute_sample_points_for_variable_depth(ray_origins,
                                             ray_directions,
                                             near_plane,
                                             far_plane,
                                             num_samples,
                                             randomize=False):
    """Sample points along rays when the near/far bounds are array-valued.

    Depths are linearly interpolated between ``near_plane`` and ``far_plane``
    at ``num_samples`` evenly spaced fractions in [0, 1].

    Args:
        ray_origins: Ray origins; indexed as ``[..., None, :]``.
        ray_directions: Ray directions; indexed as ``[..., None, :]``.
        near_plane: Near bounds, indexed as ``[:, None]``
            (presumably one value per ray - TODO confirm).
        far_plane: Far bounds; ``far_plane.shape[0]`` sets the number of rows.
        num_samples (int): Number of depth samples per row.
        randomize (bool, optional): If True, each depth is redrawn uniformly
            inside the interval bounded by the midpoints to its neighbours.
            Defaults to False.

    Returns:
        tuple: ``(sample_points, depth_values)``.
    """

    # Fractions 0 .. 1; the stop value sits slightly past 1 so that the
    # end point itself is produced by the arange.
    fraction_step = 1 / (num_samples - 1)
    fractions = F.arange(0, 1 + 1 / num_samples, fraction_step)
    num_rows = far_plane.shape[0]
    fractions = F.broadcast(fractions[None, :],
                            (num_rows, fractions.shape[0]))
    depth_values = near_plane[:, None] * (1-fractions) \
        + far_plane[:, None] * fractions

    if randomize:
        midpoints = 0.5 * (depth_values[:, :-1] + depth_values[:, 1:])
        # Interval around every sample: the outermost samples use the
        # first/last depth itself as their outer bound.
        interval_hi = F.concatenate(midpoints, depth_values[:, -1:], axis=-1)
        interval_lo = F.concatenate(depth_values[:, :1], midpoints, axis=-1)

        jitter = F.rand(shape=depth_values.shape)
        depth_values = interval_lo + (interval_hi - interval_lo) * jitter

    origins = ray_origins[..., None, :]
    directions = ray_directions[..., None, :]
    sample_points = origins + directions*depth_values[..., :, None]

    return sample_points, depth_values
Пример #2
0
def test_seed(seed):
    """Check that ``nn.seed`` resets both the Python-side RNG and ``F.rand``.

    The same ``F.rand`` variable is forwarded ten times after each of two
    ``nn.seed(seed)`` calls; both runs must produce identical values.
    """
    num_draws = 10
    rand = F.rand(shape=(1, 2, 3))

    def draw_values():
        # Forward the layer repeatedly and snapshot each output.
        drawn = []
        for _ in range(num_draws):
            rand.forward()
            drawn.append(rand.d.copy())
        return drawn

    nn.seed(seed)

    # check Python random generator
    assert R.pseed == seed
    assert R.prng.rand() == np.random.RandomState(seed).rand()

    # check NNabla layer
    values_first_run = draw_values()

    # reset seed
    nn.seed(seed)

    values_second_run = draw_values()

    # check if global random generator is reset
    first = np.array(values_first_run)
    second = np.array(values_second_run)
    assert np.all(first == second)
Пример #3
0
def test_rand_forward(seed, ctx, func_name, low, high, shape):
    """Check shape, producing function and value range of ``F.rand``."""
    with nn.context_scope(ctx):
        out = F.rand(low, high, shape, seed=seed)
    expected_shape = tuple(shape)
    assert out.shape == expected_shape
    assert out.parent.name == func_name
    out.forward()
    # Samples must lie in the half-open interval [low, high).
    values = out.d
    assert np.all(values < high)
    assert np.all(values >= low)
Пример #4
0
def test_rand_forward(seed, ctx, func_name, low, high, shape):
    """Verify the output of ``F.rand`` built under the context ``ctx``.

    Asserts the output shape, the name of the function that produced it,
    and that every sampled value satisfies ``low <= value < high``.
    """
    with nn.context_scope(ctx):
        sampled = F.rand(low, high, shape, seed=seed)
    assert sampled.shape == tuple(shape)
    producer = sampled.parent
    assert producer.name == func_name
    sampled.forward()
    data = sampled.d
    assert np.all(data < high)
    assert np.all(data >= low)
Пример #5
0
 def call(self, input):
     """Zero whole samples at random and rescale the result.

     Returns ``input`` untouched when ``self._drop_prob`` is 0. Otherwise one
     uniform random number is drawn per entry of axis 0; entries whose draw
     is below the drop probability are multiplied by 0, and the input is
     scaled by ``1 / (1 - drop_prob)``.
     """
     drop_prob = self._drop_prob
     if drop_prob == 0:
         return input
     noise = F.rand(shape=(input.shape[0], 1, 1, 1))
     keep_mask = F.greater_equal_scalar(noise, drop_prob)
     keep_scale = 1. / (1 - drop_prob)
     rescaled = F.mul_scalar(input, keep_scale)
     return F.mul2(rescaled, keep_mask)
Пример #6
0
 def graph(x1):
     """Multiply ``x1`` by seeded normal and uniform noise.

     Every intermediate variable is flagged with ``recompute=True``; only the
     final identity output is left unflagged.
     """
     def flagged(variable):
         # Mark an intermediate so that it is recomputed.
         return variable.apply(recompute=True)

     x1 = flagged(F.identity(x1))
     normal_noise = flagged(F.randn(shape=x1.shape, seed=123))
     uniform_noise = flagged(F.rand(shape=x1.shape, seed=456))
     product = flagged(F.mul2(x1, normal_noise))
     product = flagged(F.mul2(product, uniform_noise))
     return F.identity(product)
Пример #7
0
def compute_sample_points_from_rays(ray_origins,
                                    ray_directions,
                                    near_plane,
                                    far_plane,
                                    num_samples,
                                    randomize=False):
    """Sample points along each ray of a bundle for volumetric rendering.

    Args:
        ray_origins (nn.Variable or nn.NdArray): Shape is (height, width, 3) - Center of each ray from camera to grid point
        ray_directions (nn.Variable or nn.NdArray): Shape is (height, width, 3) - Direction of each projected ray from camera to grid point
        near_plane (float, nn.Variable or nn.NdArray): Position of the near
            clipping plane. When it is an nn.Variable / nn.NdArray the work is
            delegated to ``compute_sample_points_for_variable_depth``.
        far_plane (float): Position of the far clipping plane
        num_samples (int): Number of points to sample along each ray
        randomize (bool, optional): Add uniform noise, scaled by
            ``(far_plane - near_plane) / num_samples``, to every depth.
            Defaults to False.

    Returns:
        sample_points: Shape is (height, width, num_samples, 3) - Sampled points along each ray
        depth_values: Depths at which the rays are sampled. Shape is
            (1, num_samples) without randomization; with randomization it is
            broadcast against the noise of shape
            ``ray_origins.shape[:-1] + (num_samples,)``.
    """

    if isinstance(near_plane, (nn.Variable, nn.NdArray)):
        return compute_sample_points_for_variable_depth(
            ray_origins, ray_directions, near_plane, far_plane, num_samples,
            randomize)

    depth_range = far_plane - near_plane
    # The stop value lies slightly beyond far_plane so that far_plane itself
    # is emitted by the arange.
    depth_values = F.arange(near_plane,
                            far_plane + depth_range / num_samples,
                            depth_range / (num_samples - 1))
    depth_values = F.reshape(depth_values, (1, ) + depth_values.shape)
    if randomize:
        noise_shape = ray_origins.shape[:-1] + (num_samples, )
        if len(noise_shape) == 3:
            # Add a leading axis so the depths line up with 3-D noise.
            depth_values = depth_values[None, :, :]
        depth_values = depth_values + \
            F.rand(shape=noise_shape) * depth_range / num_samples

    origins = ray_origins[..., None, :]
    directions = ray_directions[..., None, :]
    sample_points = origins + directions*depth_values[..., :, None]

    return sample_points, depth_values
Пример #8
0
def combination(sample_num, choise_num):
    """Build a graph that picks ``choise_num`` random 1-based indices.

    A uniform random score is drawn for each of ``sample_num`` slots. The
    scores are passed through ``F.top_k_data`` (``reduce=False``) and
    ``F.sign``, multiplied by the 1-based slot numbers, and passed through
    ``F.top_k_data`` again.

    Args:
        sample_num (int): Number of candidate slots.
        choise_num (int): Value passed as ``k`` to both top-k calls.

    Returns:
        Output of the second ``F.top_k_data`` call.
    """
    scores = F.rand(shape=(sample_num, ))
    slot_numbers = nn.Variable.from_numpy_array(np.arange(sample_num, ) + 1)
    kept_scores = F.top_k_data(scores,
                               k=choise_num,
                               reduce=False,
                               base_axis=0)
    selection = F.sign(kept_scores, alpha=0)
    selected_slots = selection * slot_numbers
    return F.top_k_data(selected_slots, k=choise_num, base_axis=0)
Пример #9
0
 def drop_connect(self, h, p=0.2):
     """Randomly zero whole entries along axis 0 of ``h`` during training.

     Args:
         h: Input variable; returned unchanged when ``self.test`` is truthy.
         p (float, optional): Drop probability. Defaults to 0.2.

     Returns:
         ``h`` multiplied by a 0/1 mask divided by ``1 - p``.
     """
     if self.test:
         return h
     keep_prob = 1.0 - p
     # One random value per entry of axis 0, broadcast over all other axes.
     noise_shape = [1] * h.ndim
     if noise_shape:
         noise_shape[0] = h.shape[0]
     noise = F.rand(shape=noise_shape)
     noise += keep_prob
     # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
     mask = F.floor(noise)
     return h * (mask / keep_prob)
Пример #10
0
def test_rand_forward(seed, ctx, func_name, low, high, shape):
    """Check shape, producing function and value range of ``F.rand``."""
    with nn.context_scope(ctx):
        out = F.rand(low, high, shape, seed=seed)
    expected_shape = tuple(shape)
    assert out.shape == expected_shape
    assert out.parent.name == func_name
    out.forward()
    values = out.d
    # NOTE: The upper bound should be strict (< high),
    # but <= high is used because std::uniform_random contains a bug.
    assert np.all(values <= high)
    assert np.all(values >= low)
Пример #11
0
def _calc_gradient_penalty(real, fake, discriminator):
    """Return a gradient penalty evaluated on a random real/fake mix.

    A single uniform coefficient (shape ``(1, 1, 1, 1)``) blends ``real`` and
    ``fake``. The discriminator output is differentiated with respect to the
    blend, and the penalty is the mean squared deviation from 1 of the
    gradient's L2 norm taken over axis 1.
    """
    mix = F.rand(shape=(1, 1, 1, 1))
    interpolates = mix * real + (1.0 - mix) * fake
    interpolates.need_grad = True

    disc_interpolates = discriminator(x=interpolates)

    grads = nn.grad([disc_interpolates], [interpolates])
    norms = []
    for grad in grads:
        norms.append(F.sum(grad ** 2.0, axis=1) ** 0.5)
    penalties = []
    for norm in norms:
        penalties.append(F.mean((norm - 1.0) ** 2.0))
    return sum(penalties)
Пример #12
0
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw samples from the piecewise distribution defined by ``weights``.

    The weights are normalised along the last axis into a PDF, accumulated
    into a CDF with a leading zero, and inverted by linear interpolation.

    Args:
      bins: Values gathered along the last axis with the same indices as the
        CDF (presumably bin edges with one more entry than ``weights`` -
        TODO confirm against the caller).
      weights: Non-normalised weights; 1e-5 is added before normalising.
      N_samples: int. Number of samples to draw along the last axis.
      det: bool. If True, use evenly spaced values ``arange(0, 1, 1/N_samples)``
        instead of uniform random numbers.

    Returns:
      samples: Interpolated values from ``bins`` with last dimension
        ``N_samples``.
    """
    weights += 1e-5
    pdf = weights / F.sum(weights, axis=-1, keepdims=True)

    cdf = F.cumsum(pdf, axis=-1)
    leading_zero = F.constant(0, cdf[..., :1].shape)
    cdf = F.concatenate(leading_zero, cdf, axis=-1)

    if det:
        u = F.arange(0., 1., 1 / N_samples)
        u = F.broadcast(u[None, :], cdf.shape[:-1] + (N_samples, ))
        if isinstance(cdf, nn.NdArray):
            # Match the array type of the CDF.
            u = u.data
    else:
        u = F.rand(shape=cdf.shape[:-1] + (N_samples, ))

    # Locate the CDF interval that brackets every u.
    hit = F.searchsorted(cdf, u, right=True)
    lower_idx = F.maximum_scalar(hit - 1, 0)
    upper_idx = F.minimum_scalar(hit, cdf.shape[-1] - 1)
    bracket_idx = F.stack(lower_idx, upper_idx, axis=lower_idx.ndim)
    num_batch_dims = len(bracket_idx.shape) - 2
    cdf_g = F.gather(cdf,
                     bracket_idx,
                     axis=-1,
                     batch_dims=num_batch_dims)
    bins_g = F.gather(bins,
                      bracket_idx,
                      axis=-1,
                      batch_dims=num_batch_dims)

    cdf_lo, cdf_hi = cdf_g[..., 0], cdf_g[..., 1]
    width = (cdf_hi - cdf_lo)
    # Replace (near-)zero interval widths by 1 before dividing.
    width = F.where(F.less_scalar(width, 1e-5), F.constant(1, width.shape),
                    width)
    t = (u - cdf_lo) / width
    bin_lo = bins_g[..., 0]
    return bin_lo + t * (bins_g[..., 1] - bin_lo)
Пример #13
0
def random_scaling(x, lo, hi):
    r"""Multiply a Variable by a random factor drawn per entry of axis 0.

    Args:
        x (nn.Variable): Input Variable.
        lo (int): Lower bound passed to ``F.rand``.
        hi (int): Upper bound passed to ``F.rand``.

    Returns:
        nn.Variable: ``x`` times a random factor of shape
        ``(x.shape[0], 1, 1)``.
    """
    num_items = x.shape[0]
    factor = F.rand(lo, hi, shape=(num_items, 1, 1))
    return x * factor
Пример #14
0
def drop_path(x):
    """
        Randomly zero whole entries along axis 0 of ``x`` and rescale.

        The drop rate is read from the parameter "drop_rate" (shape
        (1, 1, 1, 1), need_grad=False), created on first use. An entry is
        zeroed when its uniform random draw is less than the drop rate;
        ``x`` is divided by ``1 - drop_rate``.
    """
    drop_rate = nn.parameter.get_parameter_or_create(
        "drop_rate", shape=(1, 1, 1, 1), need_grad=False)
    noise = F.rand(shape=(x.shape[0], 1, 1, 1))
    keep_mask = F.greater_equal(noise, drop_rate)
    rescaled = F.div2(x, 1 - drop_rate)
    return F.mul2(rescaled, keep_mask)
Пример #15
0
def test_rand_forward(seed, ctx, func_name, low, high, shape):
    """Check ``F.rand`` output properties and run the recomputation test."""
    with nn.context_scope(ctx):
        out = F.rand(low, high, shape, seed=seed)
    expected_shape = tuple(shape)
    assert out.shape == expected_shape
    assert out.parent.name == func_name
    out.forward()
    values = out.d
    # NOTE: The upper bound should be strict (< high),
    # but <= high is used because std::uniform_random contains a bug.
    assert np.all(values <= high)
    assert np.all(values >= low)

    # Checking recomputation
    recomputation_test(rng=None,
                       func=F.rand,
                       vinputs=[],
                       func_args=[low, high, shape, seed],
                       func_kwargs={},
                       ctx=ctx)
Пример #16
0
def loss_dis_real(logits, rec_imgs, part, img, lmd=1.0):
    """Discriminator loss on real images, plus an optional reconstruction term.

    Args:
        logits: Discriminator outputs for real images.
        rec_imgs: Reconstructed images passed to ``reconstruction_loss_lpips``.
        part: Pair of values compared against 0.5 to choose which half
            (``part[0]``, axis 2) and then which quarter (``part[1]``, axis 3)
            of the 256x256 image is the target crop.
        img: Ground-truth image.
        lmd (float, optional): Weight of the reconstruction loss; the term is
            skipped unless ``lmd > 0``. Defaults to 1.0.

    Returns:
        The scalar loss variable.
    """
    # Hinge loss with a random margin in [0.8, 1.0)
    # (following the official implementation)
    margin = 0.2*F.rand(shape=logits.shape) + 0.8
    loss = F.mean(F.relu(margin - logits))

    if not lmd > 0.0:
        return loss

    # Ground-truth targets for the reconstructions
    img_128 = F.interpolate(img, output_size=(128, 128))
    img_256 = F.interpolate(img, output_size=(256, 256))

    use_first_rows = F.greater_scalar(part[0], 0.5)
    img_half = F.where(use_first_rows,
                       img_256[:, :, :128, :], img_256[:, :, 128:, :])
    use_first_cols = F.greater_scalar(part[1], 0.5)
    img_part = F.where(use_first_cols,
                       img_half[:, :, :, :128], img_half[:, :, :, 128:])

    # Integrated perceptual loss
    perceptual = reconstruction_loss_lpips(rec_imgs, [img_128, img_part])
    return loss + lmd * perceptual
Пример #17
0
def audio(request, nb_samples, nb_channels, nb_timesteps):
    """Return a uniform random variable of shape (samples, channels, timesteps).

    ``request`` is accepted but not used.
    """
    shape = [nb_samples, nb_channels, nb_timesteps]
    return F.rand(shape=shape)
Пример #18
0
def train(args):
    """Build the GAN graphs, train them, save the result and translate test images.

    The function defines discriminator losses (adversarial, classification,
    gradient penalty) and generator losses (adversarial, classification,
    reconstruction), trains both networks with Adam, decays the learning
    rates during the second half of training, saves the parameters together
    with a JSON copy of the configuration, and finally translates
    ``args.num_test`` batches from the held-out test split.

    Args:
        args: Configuration object whose public attributes are read
            throughout (e.g. ``batch_size``, ``image_size``, ``c_dim``,
            ``selected_attrs``, ``max_epoch``, ``n_critic``). ``args.c_dim``
            is overwritten when it disagrees with ``len(args.selected_attrs)``.
    """
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    # Input placeholders: real images and original / target domain labels.
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    # x_hat is a per-sample random interpolation between real and fake images.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    # Penalize deviation of the per-sample gradient L2 norm from 1.
    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # The generator is updated once every args.n_critic iterations.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # Continue backpropagation from the unlinked variable into the
            # generator graph that produced x_fake.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    # Reuses the permutation drawn for the current training batch.
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        # (only in the second half of training, every args.lr_update_step
        # iterations, by a fixed amount and never below 0).
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    # NOTE(review): the timestamp is taken a second time here, so the config
    # file name can differ from the parameter file name across a minute change.
    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Пример #19
0
def train(args):
    """Train a generator/critic pair with a gradient penalty on CIFAR-10.

    Builds the generator loss, the discriminator loss with a gradient-penalty
    term weighted by ``args.lambda_``, and a fixed-noise test generator. Each
    outer iteration runs ``args.n_critic`` discriminator updates followed by
    one generator update. Parameters and image tiles are saved every
    ``args.save_interval`` iterations and once more after the loop.

    Args:
        args: Configuration object; attributes read here include ``context``,
            ``device_id``, ``type_config``, ``latent``, ``maps``,
            ``batch_size``, ``image_size``, ``lambda_``, ``up``, ``lrg``,
            ``lrd``, ``beta1``, ``beta2``, ``monitor_path``, ``max_iter``,
            ``n_critic`` and ``save_interval``.
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty
    # x_rmix is a per-sample random interpolation between real and fake images.
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    # Penalize deviation of the per-sample gradient L2 norm from 1.
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            # Rescale the image batch by 1/127.5 and shift by -1.
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            # NOTE(review): x_test is only explicitly forwarded after the
            # loop; confirm the monitor does not need a forward pass here.
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    # NOTE(review): `i` is the last loop index here, so this section assumes
    # args.max_iter >= 1.
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Пример #20
0
def idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf):
    """Build the total loss (color + mask + eikonal) for one batch of rays.

    Args:
        camloc: Camera locations; reshaped to ``(B, 1, 3)``.
        raydir: Ray directions of shape ``(B, R, _)``.
        alpha: Scale applied to the SDF in the mask loss; it is reshaped to
            an all-ones shape, so it is expected to hold a single value.
        color_gt: Ground-truth colors compared against the predicted colors.
        mask_obj: Object mask, used by the ray tracer and as the target of
            the mask loss.
        conf: Configuration object (ray-tracing settings, ``depth``,
            ``bounding_box_size``, ``mask_weight``, ``eikonal_weight``).

    Returns:
        tuple: ``(loss, loss_color, loss_mask, loss_eikonal, mask_hit)``.
    """
    # Setting
    B, R, _ = raydir.shape
    D = conf.depth
    # Number of free-space points sampled per batch item for the eikonal term.
    num_free = R // 2

    # Ray trace (visibility)
    x_hit, mask_hit, dists, mask_pin, mask_pout = \
        ray_trace(partial(sdf_net, conf=conf),
                  camloc, raydir, mask_obj, t_near=conf.t_near, t_far=conf.t_far,
                  sphere_trace_itr=conf.sphere_trace_itr,
                  ray_march_points=conf.ray_march_points,
                  n_chunks=conf.n_chunks,
                  max_post_itr=conf.max_post_itr,
                  post_method=conf.post_method, eps=conf.eps)

    # The ray-tracing outputs are treated as constants.
    x_hit = x_hit.apply(need_grad=False)
    mask_hit = mask_hit.apply(need_grad=False, persistent=True)
    dists = dists.apply(need_grad=False)
    mask_pin = mask_pin.apply(need_grad=False)
    mask_pout = mask_pout.apply(need_grad=False)
    mask_us = mask_pin + mask_pout
    P = F.sum(mask_us)

    # Current points
    x_curr = (camloc.reshape((B, 1, 3)) + dists * raydir).apply(need_grad=True)

    # Eikonal loss
    bounding_box_size = conf.bounding_box_size
    x_free = F.rand(-bounding_box_size,
                    bounding_box_size,
                    shape=(B, num_free, 3))
    x_point = F.concatenate(x_curr, x_free, axis=1)
    sdf_xp, _, grad_xp = sdf_feature_grad(implicit_network, x_point, conf)
    gp = (F.norm(grad_xp, axis=[grad_xp.ndim - 1], keepdims=True) - 1.0)**2.0
    loss_eikonal = F.sum(gp[:, :R, :] * mask_us) + F.sum(gp[:, R:, :])
    # Normalize by the number of contributing points: P masked ray points
    # plus B * num_free free points. (The previous `B * R // 2` parsed as
    # `(B * R) // 2`, which over-counts when R is odd.)
    loss_eikonal = loss_eikonal / (P + B * num_free)
    loss_eikonal = loss_eikonal.apply(persistent=True)

    sdf_curr = sdf_xp[:, :R, :]
    grad_curr = grad_xp[:, :R, :]

    # Mask loss
    logit = -alpha.reshape([1 for _ in range(sdf_curr.ndim)]) * sdf_curr
    loss_mask = F.sigmoid_cross_entropy(logit, mask_obj)
    loss_mask = loss_mask * mask_pout
    loss_mask = F.sum(loss_mask) / P / alpha
    loss_mask = loss_mask.apply(persistent=True)

    # Lighting
    x_hat = sample_network(x_curr, sdf_curr, raydir, grad_curr)
    _, feature, grad = sdf_feature_grad(implicit_network, x_hat, conf)
    normal = grad
    color_pred = lighting_network(x_hat, normal, feature, -raydir, D)

    # Color loss
    loss_color = F.absolute_error(color_gt, color_pred)
    loss_color = loss_color * mask_pin
    loss_color = F.sum(loss_color) / P
    loss_color = loss_color.apply(persistent=True)

    # Total loss
    loss = loss_color + conf.mask_weight * \
        loss_mask + conf.eikonal_weight * loss_eikonal

    return loss, loss_color, loss_mask, loss_eikonal, mask_hit
Пример #21
0
def augment(batch, aug_list, p_aug=1.0):
    """Apply the augmentations named in ``aug_list`` to an image batch.

    Each of "flip", "lrflip", "translation" and "color" is applied per sample
    wherever ``p_aug`` exceeds a fresh uniform random draw; "cutout" passes
    the current value of ``p_aug`` to ``F.random_erase`` as its probability.

    Args:
        batch: Image batch indexed as (batch, channel, height, width).
        aug_list: Container of augmentation names to enable.
        p_aug (int, float or nn.Variable, optional): Augmentation probability.
            A Python number is wrapped into a one-element Variable.
            Defaults to 1.0.

    Returns:
        The augmented batch.
    """

    # Accept any Python number, not only float (an int such as 1 used to be
    # passed through unwrapped and then fail in F.tile).
    if isinstance(p_aug, (int, float)):
        p_aug = nn.Variable.from_numpy_array(p_aug * np.ones((1,)))

    def _select(rnd, batch_aug, batch):
        # Take the augmented sample wherever p_aug > rnd, else the original.
        use_aug = F.greater(F.tile(p_aug, batch.shape[0]), rnd)
        return F.where(use_aug, batch_aug, batch)

    if "flip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(2, 3))
        batch = _select(rnd, batch_aug, batch)

    if "lrflip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(3,))
        batch = _select(rnd, batch_aug, batch)

    if "translation" in aug_list and batch.shape[2] >= 8:
        rnd = F.rand(shape=[batch.shape[0], ])
        # Currently nnabla does not support random_shift with border_mode="noise"
        # A mask with a zeroed one-pixel border is appended to the batch and
        # shifted together with it; multiplying by the shifted mask afterwards
        # zeroes part of the shifted images.
        # The mask uses the batch's own channel count instead of a fixed 3.
        mask = np.ones((1, batch.shape[1], batch.shape[2], batch.shape[3]))
        mask[:, :, :, 0] = 0
        mask[:, :, :, -1] = 0
        mask[:, :, 0, :] = 0
        mask[:, :, -1, :] = 0
        batch_int = F.concatenate(
            batch, nn.Variable.from_numpy_array(mask), axis=0)
        batch_int_aug = F.random_shift(batch_int, shifts=(
            batch.shape[2]//8, batch.shape[3]//8), border_mode="nearest")
        batch_aug = F.slice(batch_int_aug, start=(
            0, 0, 0, 0), stop=batch.shape)
        mask_var = F.slice(batch_int_aug, start=(
            batch.shape[0], 0, 0, 0), stop=batch_int_aug.shape)
        batch_aug = batch_aug * F.broadcast(mask_var, batch_aug.shape)
        batch = _select(rnd, batch_aug, batch)

    if "color" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        rnd_contrast = 1.0 + 0.5 * \
            (2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]
                          ) - 1.0)  # from 0.5 to 1.5
        rnd_brightness = 0.5 * \
            (2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]
                          ) - 1.0)  # from -0.5 to 0.5
        rnd_saturation = 2.0 * \
            F.rand(shape=[batch.shape[0], 1, 1, 1])  # from 0.0 to 2.0
        # Brightness
        batch_aug = batch + rnd_brightness
        # Saturation
        mean_s = F.mean(batch_aug, axis=1, keepdims=True)
        batch_aug = rnd_saturation * (batch_aug - mean_s) + mean_s
        # Contrast
        mean_c = F.mean(batch_aug, axis=(1, 2, 3), keepdims=True)
        batch_aug = rnd_contrast * (batch_aug - mean_c) + mean_c
        batch = _select(rnd, batch_aug, batch)

    if "cutout" in aug_list and batch.shape[2] >= 16:
        # NOTE: p_aug.d[0] is read when this line runs, so the erase
        # probability is the value p_aug holds at that moment.
        batch = F.random_erase(batch, prob=p_aug.d[0], replacements=(0.0, 0.0))

    return batch
Пример #22
0
def loss_dis_fake(logits):
    """Hinge loss for the discriminator on fake samples.

    The margin is random in [0.8, 1.0) (following the official
    implementation).
    """
    margin = 0.2*F.rand(shape=logits.shape) + 0.8
    return F.mean(F.relu(margin + logits))
def one_hot_combination(sample_num, choise_num):
    """Build a random selection vector over ``sample_num`` slots.

    Uniform random scores are passed through ``F.top_k_data`` with
    ``k=choise_num`` and ``reduce=False``, and the result through ``F.sign``
    with ``alpha=0``.
    """
    scores = F.rand(shape=(sample_num, ))
    kept_scores = F.top_k_data(scores,
                               k=choise_num,
                               reduce=False,
                               base_axis=0)
    return F.sign(kept_scores, alpha=0)
Пример #24
0
def Discriminator(img, label="real", scope_name="Discriminator", ndf=64):
    """Build the discriminator graph with auxiliary reconstruction decoders.

    Args:
        img: Either a single image variable, or a list/tuple whose first
            element is the full-size image and whose second element is the
            already down-scaled image. When a single image is given, the
            small image is obtained by interpolating it to 128x128.
        label: When equal to ``"real"`` the reconstructions and the random
            crop selectors are returned together with the logits; for any
            other value only the logits are returned.
        scope_name: Name of the parameter scope that holds all weights.
        ndf: Base number of feature channels; layer widths are multiples
            or fractions of it.

    Returns:
        ``logits`` if ``label != "real"``; otherwise the tuple
        ``(logits, [rec_img_big, rec_img_small, rec_img_part],
        [part_ax2, part_ax3])``.

    Note:
        The decoder branches are built regardless of ``label``; only the
        return value differs.
    """
    with nn.parameter_scope(scope_name):
        # isinstance (rather than an exact type check) so that tuples and
        # list subclasses holding (img, img_small) are accepted as well.
        if isinstance(img, (list, tuple)):
            img, img_small = img[0], img[1]
        else:
            img_small = F.interpolate(img, output_size=(128, 128))

        def sn_w(w):
            # Spectral normalisation applied to every convolution weight.
            return PF.spectral_norm(w, dim=0)

        # InitLayer: -> 256x256
        # The amount of strided down-sampling depends on the input height.
        with nn.parameter_scope("init"):
            h = img
            if img.shape[2] == 1024:
                h = PF.convolution(h,
                                   ndf // 8, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv1")
                h = F.leaky_relu(h, 0.2)
                h = PF.convolution(h,
                                   ndf // 4, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv2")
                h = PF.batch_normalization(h)
                h = F.leaky_relu(h, 0.2)
            elif img.shape[2] == 512:
                h = PF.convolution(h,
                                   ndf // 4, (4, 4),
                                   stride=(2, 2),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv2")
                h = F.leaky_relu(h, 0.2)
            else:
                # No stride here: the input resolution is kept as is.
                h = PF.convolution(h,
                                   ndf // 4, (3, 3),
                                   pad=(1, 1),
                                   apply_w=sn_w,
                                   with_bias=False,
                                   name="conv3")
                h = F.leaky_relu(h, 0.2)

        # Calc base features
        f_256 = h
        f_128 = DownsampleComp(f_256, ndf // 2, "down256->128")
        f_64 = DownsampleComp(f_128, ndf * 1, "down128->64")
        f_32 = DownsampleComp(f_64, ndf * 2, "down64->32")

        # Apply SLE
        f_32 = SLE(f_32, f_256, "sle256->32")
        f_16 = DownsampleComp(f_32, ndf * 4, "down32->16")
        f_16 = SLE(f_16, f_128, "sle128->16")
        f_8 = DownsampleComp(f_16, ndf * 16, "down16->8")
        f_8 = SLE(f_8, f_64, "sle64->8")

        # Conv + BN + LeakyReLU + Conv -> logits (5x5)
        with nn.parameter_scope("last"):
            h = PF.convolution(f_8,
                               ndf * 16, (1, 1),
                               apply_w=sn_w,
                               with_bias=False,
                               name="conv1")
            h = PF.batch_normalization(h)
            h = F.leaky_relu(h, 0.2)
            logit_large = PF.convolution(h,
                                         1, (4, 4),
                                         apply_w=sn_w,
                                         with_bias=False,
                                         name="conv2")

        # Another path: "down_from_small" in the official code
        with nn.parameter_scope("down_from_small"):
            h_s = PF.convolution(img_small,
                                 ndf // 2, (4, 4),
                                 stride=(2, 2),
                                 pad=(1, 1),
                                 apply_w=sn_w,
                                 with_bias=False,
                                 name="conv1")
            h_s = F.leaky_relu(h_s, 0.2)
            h_s = Downsample(h_s, ndf * 1, "dfs64->32")
            h_s = Downsample(h_s, ndf * 2, "dfs32->16")
            h_s = Downsample(h_s, ndf * 4, "dfs16->8")
            fea_dec_small = h_s
            logit_small = PF.convolution(h_s,
                                         1, (4, 4),
                                         apply_w=sn_w,
                                         with_bias=False,
                                         name="conv2")

        # Concatenate logits
        logits = F.concatenate(logit_large, logit_small, axis=1)

        # Reconstruct images
        rec_img_big = SimpleDecoder(f_8, "dec_big")
        rec_img_small = SimpleDecoder(fea_dec_small, "dec_small")
        # One random number per sample and per spatial axis decides which
        # half of f_16 is decoded: > 0.5 picks indices [:8], otherwise [8:].
        # NOTE(review): the fixed split at 8 presumably relies on f_16 being
        # 16x16 so that both halves have equal shape - confirm for inputs
        # that take the un-strided "conv3" branch.
        part_ax2 = F.rand(shape=(img.shape[0], ))
        part_ax3 = F.rand(shape=(img.shape[0], ))
        f_16_ax2 = F.where(F.greater_scalar(part_ax2, 0.5), f_16[:, :, :8, :],
                           f_16[:, :, 8:, :])
        f_16_part = F.where(F.greater_scalar(part_ax3, 0.5),
                            f_16_ax2[:, :, :, :8], f_16_ax2[:, :, :, 8:])
        rec_img_part = SimpleDecoder(f_16_part, "dec_part")

    if label == "real":
        # The selectors are returned so the caller can tell which part of
        # the feature map was decoded into rec_img_part.
        return logits, [rec_img_big, rec_img_small,
                        rec_img_part], [part_ax2, part_ax3]
    else:
        return logits