コード例 #1
0
ファイル: augment.py プロジェクト: tallamjr/google-research
  def random_bg(key):
    """Sample one random background image and optionally augment it.

    Three candidate backgrounds are built and one of them is drawn with the
    probabilities `config.noise_bg_prob`, `config.checker_bg_prob` and
    `config.fft_bg_prob`. The candidates are:
      * uniform noise of shape (n_clip, n_clip, 3),
      * a checkerboard from `checkerboard`,
      * a smoothed random image from `optvis.image_sample`.
    The chosen background is then Gaussian-blurred when
    `config.bg_blur_std_range` is set. Its saturation is randomized when
    `config.bg_random_saturation_range` is set.

    Args:
      key: PRNG key; `random.split` turns it into six sub-keys, one per
        random operation below.

    Returns:
      The selected, possibly blurred and (de)saturated, background image.
    """
    (noise_key, checker_key, fft_key,
     bg_sel_key, blur_key, bg_saturation_key) = random.split(key, 6)

    candidates = [
        # Uniform noise background.
        random.uniform(noise_key, (n_clip, n_clip, 3)),
        # Checkerboard background.
        checkerboard(checker_key, config.checker_bg_nsq, n_clip),
        # Random smoothed gaussian background, as used in
        # https://distill.pub/2018/differentiable-parameterizations/#section-rgba
        optvis.image_sample(
            fft_key, [1, n_clip, n_clip, 3], sd=0.2, decay_power=1.5)[0],
    ]

    # Draw one candidate according to the configured probabilities, which
    # must sum to one.
    weights = [
        config.noise_bg_prob, config.checker_bg_prob, config.fft_bg_prob
    ]
    assert onp.isclose(sum(weights), 1)
    bg = random.choice(bg_sel_key, np.stack(candidates), p=np.array(weights))

    # Optional blur; the std is drawn uniformly from [min_blur, max_blur).
    blur_range = config.get('bg_blur_std_range', None)
    if blur_range:
      min_blur, max_blur = blur_range
      blur_std = min_blur + (max_blur - min_blur) * random.uniform(blur_key)
      bg = dm_pix.gaussian_blur(bg, blur_std, kernel_size=15)

    # Optional (de)saturation; factors below 1 move toward grayscale.
    saturation_range = config.get('bg_random_saturation_range', None)
    if saturation_range:
      lower, upper = saturation_range
      bg = dm_pix.random_saturation(bg_saturation_key, bg, lower, upper)

    return bg
コード例 #2
0
    def aug_fg(key):
        """Augment, randomly crop, and resize the foreground and `acc`.

        The foreground `z` has its saturation randomized when
        `config.fg_random_saturation_range` is set. A random window of size
        `n_crop` is then cropped from it and resized to `n_clip`. The same
        window (same offsets) is cropped from `acc` and resized. The result
        is clipped from below at `config.min_aug_acc` when that option is
        set and non-zero.

        Args:
            key: PRNG key, split into a saturation key and a crop key.

        Returns:
            A `(fg, acc_crop)` tuple: the augmented foreground crop and the
            matching crop of `acc`.
        """
        saturation_key, crop_key = random.split(key, 2)
        fg = z

        if config.get('fg_random_saturation_range', None):
            lower, upper = config.fg_random_saturation_range
            fg = dm_pix.random_saturation(saturation_key, fg, lower, upper)

        # Crop and resize. `random.randint`'s maxval is exclusive, so pass
        # `max_offset + 1` to make the largest valid offset reachable. This
        # assumes `crop` takes the n_crop window starting at (ix, iy). When
        # max_offset <= 0, randint returns minval (0), same as before.
        # NOTE(review): both offsets are bounded by the extent of axis -3,
        # which assumes a square image — confirm against `crop`.
        max_offset = fg.shape[-3] - n_crop
        ix, iy = random.randint(crop_key, (2,), 0, max_offset + 1)
        fg = crop(fg, ix, iy, n_crop)
        fg = resize(fg, n_clip)

        acc_crop = crop(acc, ix, iy, n_crop)
        acc_crop = resize(acc_crop, n_clip)
        if config.get('min_aug_acc', 0.):
            # Lower bound only; no upper bound is applied.
            acc_crop = np.clip(acc_crop, config.min_aug_acc)

        return fg, acc_crop