Esempio n. 1
0
    def time_warp(self):
        """Warp ``self.mel_spectrogram`` around one random control point.

        The spectrogram is reshaped to a 4-D ``(-1, rows, cols, 1)`` array
        (the layout expected by ``sparse_image_warp``).  One control point is
        placed on the middle row at a random column in ``[W, cols - W)`` and
        moved along the column axis by a distance drawn uniformly from
        ``[-W, W)``.  The warped array is stored back on
        ``self.mel_spectrogram`` and returned.

        Returns:
            The warped spectrogram as produced by ``sparse_image_warp``.

        Raises:
            ValueError: raised by ``random.randrange`` when
                ``cols - self.W <= self.W`` (the spectrogram is too short
                for the configured warp parameter).
        """
        # Reshape to a 4-D batch for sparse_image_warp.
        # NOTE(review): `v`/`tau` suggest (freq, time) for the two original
        # axes - confirm against the caller's spectrogram layout.
        spec = self.mel_spectrogram
        self.mel_spectrogram = np.reshape(
            spec, (-1, spec.shape[0], spec.shape[1], 1))

        v, tau = self.mel_spectrogram.shape[1], self.mel_spectrogram.shape[2]
        center_row = v // 2

        # Random *position* along the column axis.  The control point must be
        # the index itself, not the spectrogram value stored at that index.
        pt = random.randrange(self.W, tau - self.W)
        w = np.random.uniform(-self.W, self.W)  # signed warp distance

        # Control points are (row, col) pairs, shaped (batch, n_points, 2).
        src_points = [[[float(center_row), float(pt)]]]
        dest_points = [[[float(center_row), float(pt + w)]]]

        self.mel_spectrogram, _ = sparse_image_warp(self.mel_spectrogram,
                                                    src_points,
                                                    dest_points,
                                                    num_boundary_points=2)

        return self.mel_spectrogram
Esempio n. 2
0
def assert_zero_shift(order, regularization, num_boundary_points):
    """Assert that a warp with all-zero displacements is the identity.

    A random 1x4x4x3 float32 image is warped with three control points whose
    destinations equal their sources; the result must match the input to
    within 1e-6.

    Args:
        order: forwarded as ``interpolation_order``.
        regularization: forwarded as ``regularization_weight``.
        num_boundary_points: forwarded unchanged.
    """
    image_shape = [1, 4, 4, 3]  # batch, height, width, channels
    input_image = tf.constant(np.float32(np.random.uniform(size=image_shape)))

    # Three control points with a leading batch axis.
    raw_points = np.expand_dims([[1.0, 1.0], [2.0, 2.0], [2.0, 1.0]], 0)
    control_point_locations = tf.constant(np.float32(raw_points))

    zero_displacements = tf.constant(
        np.float32(np.zeros(control_point_locations.shape.as_list())))

    warped_image, _ = sparse_image_warp(
        input_image,
        control_point_locations,
        control_point_locations + zero_displacements,
        interpolation_order=order,
        regularization_weight=regularization,
        num_boundary_points=num_boundary_points,
    )

    np.testing.assert_allclose(warped_image, input_image, rtol=1e-6, atol=1e-6)
Esempio n. 3
0
def test_that_backprop_runs():
    """Check that a non-zero gradient flows from the warped image to the input."""
    image_shape = [1, 9, 12, 3]  # batch, height, width, channels
    image = tf.Variable(np.random.uniform(size=image_shape), dtype=tf.float32)

    # One control point and its displacement, each with a leading batch axis.
    source_points = tf.constant(np.float32(np.expand_dims([[3.0, 3.0]], 0)))
    displacements = tf.constant(np.float32(np.expand_dims([[0.25, -0.5]], 0)))

    with tf.GradientTape() as tape:
        warped_image, _ = sparse_image_warp(
            image,
            source_points,
            source_points + displacements,
            num_boundary_points=3,
        )

    gradients = tape.gradient(warped_image, image).numpy()
    assert np.sum(np.abs(gradients)) != 0
Esempio n. 4
0
    def testThatBackpropRuns(self):
        """Run a few optimizer steps to check that gradients can be computed.

        The test is currently skipped on its first statement; everything
        below is TF1 graph-mode code kept for a later port.
        """
        self.skipTest("TODO: port to tf2.0 / eager")

        image_shape = [1, 9, 12, 3]  # batch, height, width, channels
        image = tf.Variable(np.float32(np.random.uniform(size=image_shape)))

        # One control point and its displacement, each with a batch axis.
        source_points = tf.constant(
            np.float32(np.expand_dims([[3., 3.]], 0)))
        displacements = tf.constant(
            np.float32(np.expand_dims([[0.25, -0.5]], 0)))

        warped_image, _ = sparse_image_warp(
            image,
            source_points,
            source_points + displacements,
            num_boundary_points=3)

        # Minimise the mean absolute change, with global-norm clipping.
        loss = tf.reduce_mean(tf.abs(warped_image - image))
        optimizer = tf1.train.MomentumOptimizer(0.001, 0.9)
        clipped_grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, [image]), 1.0)
        opt_func = optimizer.apply_gradients(zip(clipped_grads, [image]))
        # TODO: fix TF1 ref.
        init_op = tf1.variables.global_variables_initializer()

        with self.cached_session() as sess:
            sess.run(init_op)
            for _ in range(5):
                sess.run([loss, opt_func])
Esempio n. 5
0
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Apply a random time warp to a batched spectrogram.

    Dimension 1 of ``mel_spectrogram`` is treated as the time axis (``n``)
    and dimension 2 as the frequency axis (``v``).  A column of ``v // 2``
    control points at a random time step in
    ``[time_warping_para, n - time_warping_para)`` is moved along time by a
    random integer in ``[-time_warping_para, time_warping_para)``.

    Args:
        mel_spectrogram: tensor passed to ``sparse_image_warp`` as the image.
        time_warping_para: the warp parameter ``W``.

    Returns:
        The warped image returned by ``sparse_image_warp``.
    """
    dims = tf.shape(mel_spectrogram)
    n, v = dims[1], dims[2]

    def _batched_points(time_coords, freq_coords):
        # (time, freq) pairs as float32, shaped (1, v // 2, 2).
        pairs = tf.stack((time_coords, freq_coords), -1)
        return tf.expand_dims(tf.cast(pairs, dtype=tf.float32), 0)

    # Source: one random time step shared by every control point.
    pt = tf.random.uniform([], time_warping_para, n - time_warping_para,
                           tf.int32)
    freq_coords = tf.range(v // 2)
    src_time_coords = tf.ones_like(freq_coords) * pt
    source_control_point_locations = _batched_points(src_time_coords,
                                                     freq_coords)

    # Destination: the same points shifted along time by a random distance.
    w = tf.random.uniform([], -time_warping_para, time_warping_para,
                          tf.int32)
    dest_control_point_locations = _batched_points(src_time_coords + w,
                                                   freq_coords)

    warped_image, _ = sparse_image_warp(mel_spectrogram,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
def test_partially_or_fully_unknown_shape(shape, interpolation_order,
                                          num_boundary_points):
    """Compare a traced ``sparse_image_warp`` with the direct call.

    The function is traced with the tensor specs carried by ``shape``
    (``image``, ``source_control_point_locations``,
    ``dest_control_point_locations``) and run on an all-ones image of shape
    ``shape.input``; both outputs must equal those of the direct call.
    """
    warp_options = {
        "interpolation_order": interpolation_order,
        "num_boundary_points": num_boundary_points,
    }

    # A single control point and its displacement, shaped (1, 1, 2).
    source_points = np.asarray([3.0, 3.0]).reshape(1, 1, 2).astype(np.float32)
    displacements = np.asarray([0.25, -0.5]).reshape(1, 1,
                                                     2).astype(np.float32)
    dest_points = source_points + displacements

    fn = tf.function(sparse_image_warp).get_concrete_function(
        image=tf.TensorSpec(shape=shape.image, dtype=tf.float32),
        source_control_point_locations=tf.TensorSpec(
            shape=shape.source_control_point_locations, dtype=tf.float32),
        dest_control_point_locations=tf.TensorSpec(
            shape=shape.dest_control_point_locations, dtype=tf.float32),
        **warp_options,
    )

    image = tf.ones(shape=shape.input, dtype=tf.float32)
    expected_output = sparse_image_warp(image, source_points, dest_points,
                                        **warp_options)
    output = fn(image, source_points, dest_points, **warp_options)

    # Index 0 and index 1 of the result are compared separately.
    for idx in (0, 1):
        np.testing.assert_equal(output[idx].numpy(),
                                expected_output[idx].numpy())
Esempio n. 7
0
def time_warp(spectrogram, W=80):
    """Return a time-warped copy of a 2-D spectrogram.

    The spectrogram is read as ``(tau, f)`` = (time steps, frequency bins)
    and expanded to an image of shape ``[1, tau, f, 1]``.  A column of
    ``f // 2`` control points at a random time step in ``[W, tau - W)`` is
    moved along the time axis by a random integer in ``[-W, W)``.

    Args:
        spectrogram: 2-D array/tensor with a static ``(tau, f)`` shape.
        W: time-warp parameter; the spectrogram needs ``tau > 2 * W`` for the
            sampling range to be non-empty.

    Returns:
        A float32 tensor with the same shape as ``spectrogram``.
    """
    tau, f = spectrogram.shape

    # Source control point locations: one random time step, f // 2 freq bins.
    point = tf.random.uniform(shape=[], minval=W, maxval=tau - W,
                              dtype=tf.int32)
    freq_at_point = tf.range(f // 2)
    time_at_point = tf.ones_like(freq_at_point, dtype=tf.int32) * point

    # Destination control point locations: same bins, shifted in time.
    dt = tf.random.uniform(shape=[], minval=-W, maxval=W, dtype=tf.int32)
    dest_time_at_point = time_at_point + dt

    # Control points are (row, col) pairs.  The image below is laid out as
    # [1, tau, f, 1], so rows are time and columns are frequency: the time
    # coordinate must come first.  (Assumes sparse_image_warp follows the
    # TensorFlow Addons (row, col) convention.)
    scpt = tf.cast(tf.stack((time_at_point, freq_at_point), axis=-1),
                   dtype=tf.float32)
    scpt = tf.expand_dims(scpt, axis=0)
    dcpt = tf.cast(tf.stack((dest_time_at_point, freq_at_point), axis=-1),
                   dtype=tf.float32)
    dcpt = tf.expand_dims(dcpt, axis=0)

    spect = tf.cast(tf.reshape(spectrogram, [1, *spectrogram.shape, 1]),
                    dtype=tf.float32)
    # TODO: see if there is any way to get a 1.5-equivalent boundary setting.
    warped_dat, _ = sparse_image_warp(spect,
                                      source_control_point_locations=scpt,
                                      dest_control_point_locations=dcpt,
                                      num_boundary_points=2)
    return tf.reshape(warped_dat, spectrogram.shape)
    def assertZeroShift(self, order, regularization, num_boundary_points):
        """Assert that a warp with all-zero displacements is the identity.

        A random 1x4x4x3 float32 image is warped with three control points
        whose destinations equal their sources; the evaluated result must be
        close to the evaluated input.
        """
        image_shape = [1, 4, 4, 3]  # batch, height, width, channels
        input_image = tf.constant(
            np.float32(np.random.uniform(size=image_shape)))

        # Three control points with a leading batch axis.
        raw_points = np.expand_dims([[1., 1.], [2., 2.], [2., 1.]], 0)
        control_point_locations = tf.constant(np.float32(raw_points))

        zero_displacements = tf.constant(
            np.float32(np.zeros(control_point_locations.shape.as_list())))

        warped_image, _ = sparse_image_warp(
            input_image,
            control_point_locations,
            control_point_locations + zero_displacements,
            interpolation_order=order,
            regularization_weight=regularization,
            num_boundary_points=num_boundary_points)

        warped_image, input_image = self.evaluate([warped_image, input_image])
        self.assertAllClose(warped_image, input_image)
Esempio n. 9
0
def test_smiley_face():
    """Compare warped smiley images against the stored reference PNGs.

    The input image is warped for every combination of interpolation order
    (1, 2, 3) and boundary-point count (0, 1, 4), and each result is checked
    against the matching reference file.
    """
    source_path = get_path_to_datafile(
        "image/tests/test_data/Yellow_Smiley_Face.png")
    raw_image = load_image(source_path)

    landmarks = np.asarray([
        [64, 59],
        [180 - 64, 59],
        [39, 111],
        [180 - 39, 111],
        [90, 143],
        [58, 134],
        [180 - 58, 134],
    ])  # pyformat: disable
    landmark_shifts = np.asarray([
        [-10.5, 10.5],
        [10.5, 10.5],
        [0, 0],
        [0, 0],
        [0, -10],
        [-20, 10.25],
        [10, 10.75],
    ])

    # Swap the two coordinate columns and add a batch axis.
    swapped = [1, 0]
    control_points = tf.constant(
        np.expand_dims(np.float32(landmarks[:, swapped]), 0))
    control_point_displacements = tf.constant(
        np.expand_dims(np.float32(landmark_shifts[:, swapped]), 0))
    destination_points = control_points + control_point_displacements

    # Scale pixel values to [0, 1] and add a batch axis.
    input_image = tf.constant(np.expand_dims(np.float32(raw_image) / 255, 0))

    for interpolation_order in (1, 2, 3):
        for num_boundary_points in (0, 1, 4):
            warped_image, _ = sparse_image_warp(
                input_image,
                control_points,
                destination_points,
                interpolation_order=interpolation_order,
                num_boundary_points=num_boundary_points,
            )
            out_image = np.uint8(warped_image[0, :, :, :] * 255)

            reference_path = get_path_to_datafile(
                "image/tests/test_data/Yellow_Smiley_Face_Warp-interp" +
                "-{}-clamp-{}.png".format(interpolation_order,
                                          num_boundary_points))
            target_image = load_image(reference_path)

            # Allow a difference of up to 2 on the 0-255 scale: floating
            # point differences between devices can round the warped float
            # output to a different integer than the one stored in the PNG.
            np.testing.assert_allclose(target_image,
                                       out_image,
                                       atol=2,
                                       rtol=1e-3)
 def loss_fn():
     """Return the mean absolute difference between the warped image and ``image``."""
     # Uses `image` and the control point tensors from the enclosing scope.
     shifted_locations = (control_point_locations +
                          control_point_displacements)
     warped_image, _ = sparse_image_warp(image,
                                         control_point_locations,
                                         shifted_locations,
                                         num_boundary_points=3)
     return tf.reduce_mean(tf.abs(warped_image - image))
Esempio n. 11
0
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Apply a random time warp to a batched mel spectrogram.

    Dimension 1 of ``mel_spectrogram`` is read as the number of time steps
    ``n`` and dimension 2 as the number of frequency bins ``v``.  A random
    time step in ``[time_warping_para, n - time_warping_para)`` is chosen
    and ``v // 2`` control points at that time step are shifted by a
    distance ``w`` drawn uniformly from ``[0, time_warping_para)``.

    Debugging information (shape values, warp parameter and the chosen
    point) is printed to stdout on every call.

    Args:
        mel_spectrogram: tensor passed to ``sparse_image_warp`` as the image.
        time_warping_para: the time-warp parameter ``W``.

    Returns:
        The warped image returned by ``sparse_image_warp``.
    """
    print("时间扭曲")

    fbank_size = tf.shape(mel_spectrogram)
    # n: time steps of the spectrogram, v: frequency bins.
    n, v = fbank_size[1], fbank_size[2]
    print("n: {}".format(n))
    print("v: {}".format(v))
    print("time_warping_para: {}".format(time_warping_para))
    # Sample debug output recorded by the original author:
    #   n: 256
    #   v: 92
    #   time_warping_para: 80
    #   pt: 105

    # Step 1: time warping - set up the image-warp control points.
    # Source: random point along the time axis, e.g. in [80, 176) for the
    # sample values above.
    pt = tf.random.uniform([], time_warping_para, n - time_warping_para, tf.int32)
    print("pt: {}".format(pt))
    src_ctr_pt_freq = tf.range(v // 2)  # control points on the freq axis
    # Same shape as the freq points, every entry equal to pt.
    src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt
    # NOTE(review): points are stacked as (freq, time) while n (time) is read
    # from dimension 1; the original author switched from (time, freq)
    # deliberately - confirm against sparse_image_warp's coordinate order.
    src_ctr_pts = tf.stack((src_ctr_pt_freq, src_ctr_pt_time), -1)
    src_ctr_pts = tf.cast(src_ctr_pts, dtype=tf.float32)

    # Destination: shift by a non-negative distance (the author left the
    # symmetric [-W, W) alternative as an open question).
    w = tf.random.uniform([], 0, time_warping_para, tf.int32)
    dest_ctr_pt_freq = src_ctr_pt_freq
    dest_ctr_pt_time = src_ctr_pt_time + w
    dest_ctr_pts = tf.stack((dest_ctr_pt_freq, dest_ctr_pt_time), -1)
    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=tf.float32)

    # Warp: add the batch axis -> (1, v // 2, 2).
    source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)
    dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)

    warped_image, _ = sparse_image_warp(mel_spectrogram,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
Esempio n. 12
0
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Apply the time-warping step of SpecAugment and return a NumPy array.

    Dimension 1 of ``mel_spectrogram`` is treated as the time axis (``n``)
    and dimension 2 as the frequency axis (``v``).  A column of ``v // 2``
    control points at a random time step in
    ``[time_warping_para, n - time_warping_para)`` is moved along time by a
    random integer in ``[-time_warping_para, time_warping_para)``.

    # Arguments:
      mel_spectrogram: spectrogram batch; converted with
        ``tf.convert_to_tensor`` before warping.
      time_warping_para: augmentation parameter, "time warp parameter W".

    # Returns
      The warped spectrogram, converted with ``.numpy()``.
    """
    dims = tf.shape(mel_spectrogram)
    n, v = dims[1], dims[2]

    def _batched_points(time_coords, freq_coords):
        # (time, freq) pairs as float32, shaped (1, v // 2, 2).
        pairs = tf.stack((time_coords, freq_coords), -1)
        return tf.expand_dims(tf.cast(pairs, dtype=tf.float32), 0)

    # Source: one random time step shared by every control point.
    pt = tf.random.uniform([], time_warping_para, n - time_warping_para,
                           tf.int32)
    freq_coords = tf.range(v // 2)
    src_time_coords = tf.ones_like(freq_coords) * pt
    source_control_point_locations = _batched_points(src_time_coords,
                                                     freq_coords)

    # Destination: the same points shifted along time by a random distance.
    w = tf.random.uniform([], -time_warping_para, time_warping_para,
                          tf.int32)
    dest_control_point_locations = _batched_points(src_time_coords + w,
                                                   freq_coords)

    warped_image, _ = sparse_image_warp(tf.convert_to_tensor(mel_spectrogram),
                                        source_control_point_locations,
                                        dest_control_point_locations,
                                        regularization_weight=1e-5,
                                        num_boundary_points=1)
    return warped_image.numpy()
Esempio n. 13
0
    def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use):
        """Warp a 7x7 image holding one lit pixel and check the result.

        A control point on the lit pixel at (3, 3) is displaced by
        ``[0.0, 1.0]``.  The test then compares one warped pixel with one
        input pixel and checks that the flow is zero at the four corners.
        """
        height = width = 7
        pixels = np.zeros([1, height, width, 3])  # batch, h, w, channels
        pixels[:, 3, 3, :] = 1.0
        input_image = tf.constant(pixels, dtype=type_to_use)

        # Control point on the single white pixel.
        source_points = tf.constant(
            np.float32(np.expand_dims([[3.0, 3.0]], 0)), dtype=type_to_use)
        # Shift it one pixel to the right.
        displacements = tf.constant(
            np.float32(np.expand_dims([[0.0, 1.0]], 0)), dtype=type_to_use)

        warp_result = sparse_image_warp(
            input_image,
            source_points,
            source_points + displacements,
            interpolation_order=order,
            num_boundary_points=num_boundary_points,
        )
        warped_image, flow = warp_result

        warped_image, input_image, flow = self.evaluate(
            [warped_image, input_image, flow])

        tolerance = {"atol": 1e-5, "rtol": 1e-5}

        # Check that it moved the pixel correctly.
        self.assertAllClose(warped_image[0, 4, 5, :],
                            input_image[0, 4, 4, :],
                            **tolerance)

        # Test that there is no flow at the corners.
        corners = [(row, col)
                   for row in (0, height - 1)
                   for col in (0, width - 1)]
        for row, col in corners:
            self.assertAllClose(flow[0, row, col, :],
                                np.zeros([2]),
                                **tolerance)
Esempio n. 14
0
def time_warp(data, w=80):
    """Warp a batch of spectrograms along the time axis.

    A random time bin in ``[w, n_time_bins - w)`` is chosen and three control
    points at that bin (bottom, middle and top of the frequency axis) are
    shifted in time by a random integer in ``[-w, w)``.

    Args:
        data: batch of spectrograms, shape
            [n_examples, n_freq_bins, n_time_bins, n_channels] (NHWC).
        w: warp parameter (see above).

    Returns:
        The warped batch (first element of ``sparse_image_warp``'s result).
    """
    _, n_freq_bins, n_time_bins, _ = tf.shape(data)

    # Random anchor position on the time axis.
    t = tf.random.uniform(shape=(),
                          minval=w,
                          maxval=n_time_bins - w,
                          dtype=tf.int32)

    # Random signed translation along the time axis, as float32.
    shift = tf.random.uniform(shape=(), minval=-w, maxval=w, dtype=tf.int32)
    tv = tf.cast(shift, tf.float32)

    # Frequency (y) coordinates shared by source and destination points.
    freq_coords = tf.convert_to_tensor([
        0.0,
        tf.cast(n_freq_bins, tf.float32) / 2.0,
        tf.cast(n_freq_bins - 1, tf.float32)
    ])

    # Time (x) coordinates before and after the shift.
    src_times = tf.convert_to_tensor([t, t, t], dtype=tf.float32)
    dst_times = src_times + tv

    # (freq, time) pairs with a leading batch axis.
    src_points = tf.expand_dims(tf.stack([freq_coords, src_times], axis=-1),
                                0)
    dst_points = tf.expand_dims(tf.stack([freq_coords, dst_times], axis=-1),
                                0)

    warped, _ = sparse_image_warp(data,
                                  src_points,
                                  dst_points,
                                  num_boundary_points=1)[:2]
    return warped
Esempio n. 15
0
    def time_warp(image):
        """Warp ``image`` between two randomly chosen positions.

        The start position is drawn from ``[MAX_SHIFT, W - MAX_SHIFT)`` and
        the end position is the start plus a random offset in
        ``[-MAX_SHIFT, MAX_SHIFT)``.  ``get_points`` turns each position into
        control points.  The warped result is returned with all size-1
        dimensions squeezed out.
        """
        # Start position, then the signed offset that gives the end position.
        start = TF.random.uniform([],
                                  minval=MAX_SHIFT,
                                  maxval=(W - MAX_SHIFT),
                                  dtype=TF.int32)
        offset = TF.random.uniform([],
                                   minval=-MAX_SHIFT,
                                   maxval=MAX_SHIFT,
                                   dtype=TF.int32)
        end = offset + start

        # Control points before and after the warp.
        source_points = get_points(start)
        destination_points = get_points(end)

        # sparse_image_warp is given a leading batch axis on every input.
        batched = [TF.expand_dims(item, 0)
                   for item in (image, source_points, destination_points)]

        warped_image, _ = sparse_image_warp(*batched)
        return TF.squeeze(warped_image)
Esempio n. 16
0
    def sparse_warp(self, mel_spectrogram, time_warping_para=80):
        """Apply a random time warp to a batched mel spectrogram.

        Dimension 1 of ``mel_spectrogram`` is read as the number of time
        steps ``n`` and dimension 2 as the number of frequency bins ``v``.
        A column of ``v // 2`` control points at a random time step in
        ``[time_warping_para, n - time_warping_para)`` is moved along time by
        a random integer in ``[-time_warping_para, time_warping_para)``; the
        destination time is clamped to ``[0, n]``.

        # Arguments:
        mel_spectrogram: tensor passed to ``sparse_image_warp`` as the image.
        time_warping_para: augmentation parameter, "time warp parameter W".

        # Returns
        The warped image returned by ``sparse_image_warp``.
        """
        fbank_size = tf.shape(mel_spectrogram)
        n, v = fbank_size[1], fbank_size[2]

        # Source control points: one random time step, v // 2 freq bins.
        pt = tf.random.uniform([], time_warping_para, n - time_warping_para,
                               tf.int32)  # random point along the time axis
        src_ctr_pt_freq = tf.range(v // 2)  # control points on freq-axis
        src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt  # on time-axis
        src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
        src_ctr_pts = tf.cast(src_ctr_pts, dtype=tf.float32)

        # Destination control points: same bins, shifted in time.
        w = tf.random.uniform([], -time_warping_para, time_warping_para,
                              tf.int32)  # distance
        dest_ctr_pt_freq = src_ctr_pt_freq
        # Clamp to [0, n].  Clipping keeps the (v // 2,) shape, so the stack
        # below stays valid; the previous tf.repeat(..., v) replacement had
        # length v and could not be stacked with the v // 2 freq points.  It
        # also works in graph mode, unlike a Python `if` on a tensor.
        dest_ctr_pt_time = tf.clip_by_value(src_ctr_pt_time + w, 0, n)
        dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
        dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=tf.float32)

        # Warp: add the batch axis -> (1, v // 2, 2).
        source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)
        dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)

        warped_image, _ = sparse_image_warp(mel_spectrogram,
                                            source_control_point_locations,
                                            dest_control_point_locations)
        return warped_image
Esempio n. 17
0
def visualize_augment() -> None:
    """Warp a synthetic grid image and plot it next to the original.

    An 8x8 grid of constant-valued cells (8x8 pixels each, cell value
    ``(x + 1) * (y + 1)``) is built.  Random source landmarks and randomly
    shifted destination landmarks are generated and used for a direct
    ``sparse_image_warp`` call (only its output shape is logged).  The grid
    is then warped with ``warp_spectrograms``, and the original and warped
    grids are plotted with the generated landmarks overlaid.  The figure is
    saved to ``plot_models/warp_grid.pdf`` (the folder is created when
    missing) and shown.
    """
    logg = logging.getLogger(f"c.{__name__}.visualize_augment")
    # logg.setLevel("INFO")
    logg.debug("Start visualize_augment")

    # build a sample image to warp
    grid_stride = 8
    grid_w = 8
    grid_h = 8
    grid = np.zeros((grid_w * grid_stride, grid_h * grid_stride), dtype=np.float32)
    for x in range(grid_w):
        for y in range(grid_h):
            xs = x * grid_stride
            ys = y * grid_stride
            grid[xs:xs + grid_stride, ys:ys + grid_stride] = (x + 1) * (y + 1)

    ################################
    #   Random shift
    ################################

    rng = np.random.default_rng()
    num_landmarks = 4
    max_warp_time = 2
    max_warp_freq = 2
    num_samples = 1
    spec_dim = grid.shape

    # expand the grid with batch and channel dimension
    grid_b = np.expand_dims(grid, axis=0)
    grid_b = np.expand_dims(grid_b, axis=-1)

    # the shape of the landmark for one dimension
    land_shape = num_samples, num_landmarks

    # the source point has to be at least max_warp_* from the border
    bounds_time = (max_warp_time, spec_dim[0] - max_warp_time)
    bounds_freq = (max_warp_freq, spec_dim[1] - max_warp_freq)

    # generate (num_sample, num_landmarks) time/freq positions
    source_land_t = rng.uniform(*bounds_time, size=land_shape).astype(np.float32)
    source_land_f = rng.uniform(*bounds_freq, size=land_shape).astype(np.float32)
    source_lnd = np.dstack((source_land_t, source_land_f))
    logg.debug(f"land_t.shape: {source_land_t.shape}")
    logg.debug(f"source_lnd.shape: {source_lnd.shape}")

    # generate the deltas, how much to shift each point
    delta_t = rng.uniform(-max_warp_time, max_warp_time, size=land_shape)
    delta_f = rng.uniform(-max_warp_freq, max_warp_freq, size=land_shape)
    dest_land_t = source_land_t + delta_t
    dest_land_f = source_land_f + delta_f
    dest_lnd = np.dstack((dest_land_t, dest_land_f)).astype(np.float32)
    logg.debug(f"dest_lnd.shape: {dest_lnd.shape}")

    # direct warp with the landmarks generated above; only the shape is used
    data_warped, _ = sparse_image_warp(
        grid_b, source_lnd, dest_lnd, num_boundary_points=2
    )
    logg.debug(f"data_warped.shape: {data_warped.shape}")

    # NOTE(review): warp_spectrograms only receives the landmark count, the
    # warp bounds and the rng, so the landmarks plotted below may not be the
    # ones it used to produce grid_warped_b - confirm.
    grid_warped_b = warp_spectrograms(
        grid_b, num_landmarks, max_warp_time, max_warp_freq, rng
    )

    logg.debug(f"grid_warped_b.shape: {grid_warped_b.shape}")

    # extract the single image
    grid_warped = grid_warped_b[0][:, :, 0].numpy()
    logg.debug(f"grid_warped.shape: {grid_warped.shape}")

    # plot all the results
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

    im_args = {"origin": "lower", "cmap": "YlOrBr"}
    axes[0].imshow(grid.T, **im_args)
    axes[1].imshow(grid_warped.T, **im_args)

    pl_args = {"linestyle": "none", "color": "c", "markersize": 8}
    pl_args["marker"] = "d"
    axes[0].plot(*source_lnd.T, label="Source landmarks", **pl_args)
    axes[0].set_title("Original image")

    axes[1].plot(*dest_lnd.T, label="Dest landmarks", **pl_args)
    axes[1].set_title("Warped image")

    fig.tight_layout()

    # savefig does not create missing directories, so make sure it exists
    plot_folder = Path("plot_models")
    plot_folder.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_folder / "warp_grid.pdf")

    plt.show()