Example #1
    def build_train_graph(self,
                          x,
                          t=None,
                          dropout=0,
                          noise=None,
                          loss_scaling=None):
        B, C, H, W = x.shape
        if self.randflip:
            x = F.random_flip(x)
            assert x.shape == (B, C, H, W)

        if t is None:
            t = F.randint(low=0,
                          high=self.diffusion.num_timesteps,
                          shape=(B, ))
            # F.randint may return high itself with very low probability; clip to keep t in range.
            t = F.clip_by_value(t,
                                min=0,
                                max=self.diffusion.num_timesteps - 0.5)

        loss_dict = self.diffusion.train_loss(model=partial(self._denoise,
                                                            dropout=dropout),
                                              x_start=x,
                                              t=t,
                                              noise=noise)
        assert isinstance(loss_dict, AttrDict)

        # setup training loss
        loss_dict.batched_loss = loss_dict.mse
        if is_learn_sigma(self.model_var_type):
            assert "vlb" in loss_dict
            loss_dict.batched_loss += loss_dict.vlb * 1e-3
            # todo: implement loss aware sampler

        if loss_scaling is not None and loss_scaling > 1:
            loss_dict.batched_loss *= loss_scaling

        # setup flat training loss
        loss_dict.loss = F.mean(loss_dict.batched_loss)
        assert loss_dict.batched_loss.shape == t.shape == (B, )

        # Keep intermediate values persistent so per-quantile losses can be computed later
        t.persistent = True
        for v in loss_dict.values():
            v.persistent = True

        return loss_dict, t
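
A hypothetical driver for this graph (the model object, batch shape, and solver settings below are illustrative, not from the source):

import numpy as np
import nnabla as nn
import nnabla.solvers as S

# `model` is assumed to be the instance defining build_train_graph above,
# with its parameters already registered in the parameter scope.
x = nn.Variable((16, 3, 32, 32))                # input image batch
loss_dict, t = model.build_train_graph(x, dropout=0.1)

solver = S.Adam(alpha=1e-4)                     # hypothetical hyperparameters
solver.set_parameters(nn.get_parameters())

x.d = np.random.randn(*x.shape)                 # stand-in for a real data batch
loss_dict.loss.forward()
solver.zero_grad()
loss_dict.loss.backward()
solver.update()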
Example #2
import numpy as np
import nnabla as nn
import nnabla.functions as F
from nnabla.testing import assert_allclose


def test_random_flip_forward_backward(seed, axes, ctx, func_name):
    from nbla_test_utils import cap_ignore_region, function_tester
    rng = np.random.RandomState(seed)
    inputs = [rng.randn(2, 3, 4).astype(np.float32)]
    i = nn.Variable(inputs[0].shape, need_grad=True)
    i.d = inputs[0]
    # NNabla forward
    with nn.context_scope(ctx), nn.auto_forward():
        o = F.random_flip(i, axes, 0, seed)
    # The forward output is either the flipped input or, when the random draw
    # chooses not to flip, the input itself.
    flip_close = np.allclose(o.d, ref_flip(inputs[0], axes))
    assert flip_close or np.allclose(o.d, i.d)
    assert o.parent.name == func_name

    # NNabla backward
    orig_grad = rng.randn(*i.shape).astype(i.data.dtype)
    i.g[...] = orig_grad
    o_grad = rng.randn(*i.shape).astype(i.data.dtype)
    o.g = o_grad
    o.parent.backward([i], [o])

    # Verify
    if flip_close:
        ref_grad = ref_flip(o_grad, axes)
    else:
        ref_grad = o_grad
    assert_allclose(i.g, orig_grad + ref_grad)

    # Check if accum option works
    i.g[...] = 1
    o.g = o_grad
    o.parent.backward([i], [o], [False])
    assert_allclose(i.g, ref_grad)

    # Check accum=False with NaN gradient
    i.g = np.float32('nan')
    o.parent.backward([i], [o], [False])
    assert not np.any(np.isnan(i.g))

    # Check if need_grad works
    i.g[...] = 0
    i.need_grad = False
    o.backward(o_grad)
    assert np.all(i.g == 0)
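
The test relies on a NumPy reference ref_flip defined elsewhere in the test module; a minimal sketch of what it is assumed to compute:

import numpy as np

def ref_flip(x, axes):
    # Assumed reference implementation: flip along every axis in axes;
    # None is taken to mean the last axis, mirroring F.random_flip's default.
    if axes is None:
        axes = [x.ndim - 1]
    x = x.copy()
    for axis in axes:
        x = np.flip(x, axis=axis)
    return x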
Example #3
import numpy as np
import nnabla as nn
import nnabla.functions as F


def augment(batch, aug_list, p_aug=1.0):
    if isinstance(p_aug, float):
        p_aug = nn.Variable.from_numpy_array(p_aug * np.ones((1,)))

    if "flip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(2, 3))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "lrflip" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        batch_aug = F.random_flip(batch, axes=(3,))
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "translation" in aug_list and batch.shape[2] >= 8:
        rnd = F.rand(shape=[batch.shape[0], ])
        # Currently nnabla does not support random_shift with border_mode="noise"
        # Zero a one-pixel border in the mask so pixels shifted in from the
        # edge are suppressed afterwards (the mask assumes 3-channel input).
        # The mask is appended to the batch so it receives the same shift.
        mask = np.ones((1, 3, batch.shape[2], batch.shape[3]))
        mask[:, :, :, 0] = 0
        mask[:, :, :, -1] = 0
        mask[:, :, 0, :] = 0
        mask[:, :, -1, :] = 0
        batch_int = F.concatenate(
            batch, nn.Variable.from_numpy_array(mask), axis=0)
        batch_int_aug = F.random_shift(batch_int, shifts=(
            batch.shape[2]//8, batch.shape[3]//8), border_mode="nearest")
        batch_aug = F.slice(batch_int_aug, start=(
            0, 0, 0, 0), stop=batch.shape)
        mask_var = F.slice(batch_int_aug, start=(
            batch.shape[0], 0, 0, 0), stop=batch_int_aug.shape)
        batch_aug = batch_aug * F.broadcast(mask_var, batch_aug.shape)
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "color" in aug_list:
        rnd = F.rand(shape=[batch.shape[0], ])
        # Per-sample random factors: contrast in [0.5, 1.5), brightness in
        # [-0.5, 0.5), saturation in [0.0, 2.0).
        rnd_contrast = 1.0 + 0.5 * (
            2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)
        rnd_brightness = 0.5 * (
            2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1]) - 1.0)
        rnd_saturation = 2.0 * F.rand(shape=[batch.shape[0], 1, 1, 1])
        # Brightness
        batch_aug = batch + rnd_brightness
        # Saturation
        mean_s = F.mean(batch_aug, axis=1, keepdims=True)
        batch_aug = rnd_saturation * (batch_aug - mean_s) + mean_s
        # Contrast
        mean_c = F.mean(batch_aug, axis=(1, 2, 3), keepdims=True)
        batch_aug = rnd_contrast * (batch_aug - mean_c) + mean_c
        batch = F.where(
            F.greater(F.tile(p_aug, batch.shape[0]), rnd), batch_aug, batch)

    if "cutout" in aug_list and batch.shape[2] >= 16:
        batch = F.random_erase(batch, prob=p_aug.d[0], replacements=(0.0, 0.0))

    return batch
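
A hypothetical call, applying left-right flips and color jitter to roughly half of an NCHW batch (shapes and probability are illustrative):

batch = nn.Variable((32, 3, 64, 64))
batch.d = np.random.rand(32, 3, 64, 64)      # stand-in for real images
augmented = augment(batch, aug_list=["lrflip", "color"], p_aug=0.5)
augmented.forward()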
Example #4
import nnabla.functions as F


def image_augmentation(args, img, seg):
    # Concatenate image and segmentation along the channel axis so the same
    # random crop and flip are applied to both.
    imgseg = F.concatenate(img, seg, axis=1)
    imgseg = F.random_crop(imgseg, shape=(args.fineSizeH, args.fineSizeW))
    if not args.no_flip:
        imgseg = F.random_flip(imgseg, axes=(3, ))
    return imgseg
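
Hypothetical usage; img and seg are assumed to be NCHW variables with matching spatial sizes, and the caller splits the result back along the channel axis:

imgseg = image_augmentation(args, img, seg)
# Split the jointly-augmented tensor back into image and segmentation
# channels (channel counts come from the original inputs).
img_aug = F.slice(imgseg, start=(0, 0, 0, 0),
                  stop=(imgseg.shape[0], img.shape[1]) + imgseg.shape[2:])
seg_aug = F.slice(imgseg, start=(0, img.shape[1], 0, 0),
                  stop=imgseg.shape)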