def time_warp(self):
    """Randomly warp ``self.mel_spectrogram`` along its second axis.

    Steps:
    1. Reshapes the 2-D ``self.mel_spectrogram`` to ``[1, H, W, 1]``.
    2. Picks a random column index in ``[self.W, W - self.W)``.
    3. Moves the control point at (H // 2, index) by a random amount in
       ``[-self.W, self.W)`` along the column axis, using ``sparse_image_warp``
       with two boundary points per edge.

    The warped (still 4-D) result is stored back on ``self.mel_spectrogram``
    and returned.
    """
    # Reshape to [batch_size, height, width, 1] for the sparse_image_warp func.
    self.mel_spectrogram = np.reshape(
        self.mel_spectrogram,
        (-1, self.mel_spectrogram.shape[0], self.mel_spectrogram.shape[1], 1))
    v, tau = self.mel_spectrogram.shape[1], self.mel_spectrogram.shape[2]

    # Random *index* along the width (time) axis, at least W from either edge.
    # The control point must be a coordinate, not the spectrogram value stored
    # at that coordinate.
    point_to_warp = random.randrange(self.W, tau - self.W)
    w = float(np.random.uniform(-self.W, self.W))  # warp distance

    # Control points are (row, col); the row is the horizontal centre line.
    center_row = float(v // 2)
    src_points = [[[center_row, float(point_to_warp)]]]
    dest_points = [[[center_row, float(point_to_warp) + w]]]

    self.mel_spectrogram, _ = sparse_image_warp(
        self.mel_spectrogram, src_points, dest_points, num_boundary_points=2)
    return self.mel_spectrogram
def assert_zero_shift(order, regularization, num_boundary_points):
    """Check that warping with zero displacements doesn't change the image.

    A random 1x4x4x3 image is warped with three control points whose source
    and destination coincide. The output must match the input within 1e-6.
    """
    batch, height, width, depth = 1, 4, 4, 3
    original = tf.constant(
        np.float32(np.random.uniform(size=[batch, height, width, depth])))

    anchors = tf.constant(np.float32(
        np.expand_dims([[1.0, 1.0], [2.0, 2.0], [2.0, 1.0]], 0)))
    no_motion = tf.constant(np.float32(np.zeros(anchors.shape.as_list())))

    warped, _ = sparse_image_warp(
        original,
        anchors,
        anchors + no_motion,
        interpolation_order=order,
        regularization_weight=regularization,
        num_boundary_points=num_boundary_points,
    )
    np.testing.assert_allclose(warped, original, rtol=1e-6, atol=1e-6)
def test_that_backprop_runs():
    """Making sure the gradients can be computed.

    Differentiates the warped image with respect to a variable input image
    and requires at least one non-zero gradient entry.
    """
    height, width = 9, 12
    image = tf.Variable(
        np.random.uniform(size=[1, height, width, 3]),
        dtype=tf.float32,
    )
    source = tf.constant(np.float32(np.expand_dims([[3.0, 3.0]], 0)))
    shift = tf.constant(np.float32(np.expand_dims([[0.25, -0.5]], 0)))

    with tf.GradientTape() as tape:
        warped, _ = sparse_image_warp(
            image,
            source,
            source + shift,
            num_boundary_points=3,
        )
    grads = tape.gradient(warped, image).numpy()
    assert np.abs(grads).sum() != 0
def testThatBackpropRuns(self):
    """Run optimization to ensure that gradients can be computed.

    Currently skipped: the body is TF1 graph/session code that has not been
    ported to TF2 eager mode. Everything after ``skipTest`` is unreachable.
    """
    self.skipTest("TODO: port to tf2.0 / eager")
    shape = [1, 9, 12, 3]  # batch, height, width, channels
    image = tf.Variable(np.float32(np.random.uniform(size=shape)))
    src = tf.constant(np.float32(np.expand_dims([[3., 3.]], 0)))
    delta = tf.constant(np.float32(np.expand_dims([[0.25, -0.5]], 0)))
    warped_image, _ = sparse_image_warp(image, src, src + delta,
                                        num_boundary_points=3)
    loss = tf.reduce_mean(tf.abs(warped_image - image))
    optimizer = tf1.train.MomentumOptimizer(0.001, 0.9)
    clipped, _ = tf.clip_by_global_norm(tf.gradients(loss, [image]), 1.0)
    train_op = optimizer.apply_gradients(zip(clipped, [image]))
    init_op = tf1.variables.global_variables_initializer()  # TODO: fix TF1 ref.
    with self.cached_session() as sess:
        sess.run(init_op)
        for _ in range(5):
            sess.run([loss, train_op])
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Time-warp a batched spectrogram with ``sparse_image_warp``.

    Axis 1 of ``mel_spectrogram`` (length n) is treated as time and axis 2
    (length v) as frequency. Control points are the first ``v // 2`` bins of
    axis 2 at one random time index in ``[W, n - W)``, shifted in time by a
    random int in ``[-W, W)``, where ``W = time_warping_para``.

    Returns:
        The warped image produced by ``sparse_image_warp``.
    """
    dims = tf.shape(mel_spectrogram)
    n, v = dims[1], dims[2]
    W = time_warping_para

    anchor = tf.random.uniform([], W, n - W, tf.int32)  # random point along time
    freq_idx = tf.range(v // 2)  # control points on freq-axis
    src_time = tf.ones_like(freq_idx) * anchor
    shift = tf.random.uniform([], -W, W, tf.int32)  # warp distance
    dst_time = src_time + shift

    # (time, freq) pairs with a leading batch axis -> (1, v // 2, 2).
    src = tf.expand_dims(
        tf.cast(tf.stack((src_time, freq_idx), -1), dtype=tf.float32), 0)
    dst = tf.expand_dims(
        tf.cast(tf.stack((dst_time, freq_idx), -1), dtype=tf.float32), 0)

    warped_image, _ = sparse_image_warp(mel_spectrogram, src, dst)
    return warped_image
def test_partially_or_fully_unknown_shape(shape, interpolation_order, num_boundary_points):
    """A traced ``sparse_image_warp`` with relaxed input specs must match eager.

    ``shape`` supplies the TensorSpec shapes (``image``,
    ``source_control_point_locations``, ``dest_control_point_locations``)
    and the concrete ``input`` shape used to build the test image.
    """
    src = np.asarray([3.0, 3.0]).reshape(1, 1, 2).astype(np.float32)
    moves = np.asarray([0.25, -0.5]).reshape(1, 1, 2).astype(np.float32)
    dst = src + moves

    traced = tf.function(sparse_image_warp).get_concrete_function(
        image=tf.TensorSpec(shape=shape.image, dtype=tf.float32),
        source_control_point_locations=tf.TensorSpec(
            shape=shape.source_control_point_locations, dtype=tf.float32),
        dest_control_point_locations=tf.TensorSpec(
            shape=shape.dest_control_point_locations, dtype=tf.float32),
        interpolation_order=interpolation_order,
        num_boundary_points=num_boundary_points,
    )

    image = tf.ones(shape=shape.input, dtype=tf.float32)
    options = dict(interpolation_order=interpolation_order,
                   num_boundary_points=num_boundary_points)
    expected_output = sparse_image_warp(image, src, dst, **options)
    output = traced(image, src, dst, **options)

    # Compare both the warped image (0) and the dense flow (1).
    for idx in (0, 1):
        np.testing.assert_equal(output[idx].numpy(), expected_output[idx].numpy())
def time_warp(spectrogram, W=80):
    """Return a randomly time-warped copy of a 2-D spectrogram.

    Args:
        spectrogram: 2-D array/tensor unpacked as ``(tau, f)``. Axis 0 is
            treated as time (``tau`` steps) and axis 1 as frequency (``f`` bins).
        W: time-warp parameter. The anchor time is drawn from ``[W, tau - W)``
            and the shift from ``[-W, W)``, both as int32.

    Returns:
        float32 tensor reshaped back to ``spectrogram.shape``.
    """
    tau, f = spectrogram.shape

    # Source control points: the first f // 2 frequency bins at one random time.
    point = tf.random.uniform(shape=[], minval=W, maxval=tau - W, dtype=tf.int32)
    freq_at_point = tf.range(f // 2)
    time_at_point = tf.ones_like(freq_at_point, dtype=tf.int32) * point
    # sparse_image_warp takes (row, col) coordinates. The image built below is
    # [1, tau, f, 1], so rows are time and columns are frequency: (time, freq).
    scpt = tf.cast(tf.stack((time_at_point, freq_at_point), axis=-1), dtype=tf.float32)
    scpt = tf.expand_dims(scpt, axis=0)

    # Destination control points: same bins, shifted in time by dt.
    dt = tf.random.uniform(shape=[], minval=-W, maxval=W, dtype=tf.int32)
    dest_time_at_point = time_at_point + dt
    dcpt = tf.cast(tf.stack((dest_time_at_point, freq_at_point), axis=-1), dtype=tf.float32)
    dcpt = tf.expand_dims(dcpt, axis=0)

    spect = tf.cast(tf.reshape(spectrogram, [1, *spectrogram.shape, 1]), dtype=tf.float32)
    warped_dat, _ = sparse_image_warp(spect,
                                      source_control_point_locations=scpt,
                                      dest_control_point_locations=dcpt,
                                      num_boundary_points=2)
    # Need to see if there is any way to have 1.5-equivalent
    warped_dat = tf.reshape(warped_dat, spectrogram.shape)
    return warped_dat
def assertZeroShift(self, order, regularization, num_boundary_points):
    """Check that warping with zero displacements doesn't change the image.

    Uses the test case's ``evaluate`` and ``assertAllClose`` helpers.
    """
    noise = np.random.uniform(size=[1, 4, 4, 3])  # batch, height, width, channels
    input_image = tf.constant(np.float32(noise))

    anchors = tf.constant(np.float32(
        np.expand_dims([[1., 1.], [2., 2.], [2., 1.]], 0)))
    still = tf.constant(np.float32(np.zeros(anchors.shape.as_list())))

    warped_image, _ = sparse_image_warp(
        input_image,
        anchors,
        anchors + still,
        interpolation_order=order,
        regularization_weight=regularization,
        num_boundary_points=num_boundary_points)
    warped_np, input_np = self.evaluate([warped_image, input_image])
    self.assertAllClose(warped_np, input_np)
def test_smiley_face():
    """Check warping accuracy by comparing to hardcoded warped images.

    The smiley PNG is warped for every combination of interpolation order
    (1, 2, 3) and boundary-point count (0, 1, 4). Each result is compared with
    the matching reference PNG, allowing a difference of up to 2 on the 0-255
    scale.
    """
    source_png = get_path_to_datafile(
        "image/tests/test_data/Yellow_Smiley_Face.png")
    pixels = load_image(source_png)

    # Control points and their displacements, written as (x, y) pairs.
    points_xy = np.asarray([
        [64, 59],
        [180 - 64, 59],
        [39, 111],
        [180 - 39, 111],
        [90, 143],
        [58, 134],
        [180 - 58, 134],
    ])  # pyformat: disable
    moves_xy = np.asarray([
        [-10.5, 10.5],
        [10.5, 10.5],
        [0, 0],
        [0, 0],
        [0, -10],
        [-20, 10.25],
        [10, 10.75],
    ])

    # sparse_image_warp wants (row, col) = (y, x), so swap the columns.
    control_points = tf.constant(
        np.expand_dims(np.float32(points_xy[:, [1, 0]]), 0))
    control_point_displacements = tf.constant(
        np.expand_dims(np.float32(moves_xy[:, [1, 0]]), 0))
    input_image = tf.constant(np.expand_dims(np.float32(pixels) / 255, 0))

    for interpolation_order in (1, 2, 3):
        for num_boundary_points in (0, 1, 4):
            warped_image, _ = sparse_image_warp(
                input_image,
                control_points,
                control_points + control_point_displacements,
                interpolation_order=interpolation_order,
                num_boundary_points=num_boundary_points,
            )
            out_image = np.uint8(warped_image[0, :, :, :] * 255)
            target_file = get_path_to_datafile(
                "image/tests/test_data/Yellow_Smiley_Face_Warp-interp"
                f"-{interpolation_order}-clamp-{num_boundary_points}.png")
            target_image = load_image(target_file)
            # Floating-point differences between devices can round the float
            # output to a different integer than the stored PNG, hence atol=2.
            np.testing.assert_allclose(target_image, out_image, atol=2, rtol=1e-3)
def loss_fn():
    """Mean absolute difference between the warped ``image`` and ``image``.

    ``image``, ``control_point_locations`` and ``control_point_displacements``
    are free names resolved from the enclosing scope.
    """
    destination = control_point_locations + control_point_displacements
    warped, _ = sparse_image_warp(image, control_point_locations, destination,
                                  num_boundary_points=3)
    return tf.reduce_mean(tf.abs(warped - image))
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Time-warp a batched mel spectrogram (SpecAugment time warping).

    Picks a random point along the time axis in ``[W, n - W)``. The points on
    the first ``v // 2`` frequency bins at that time are shifted left or right
    by a random distance ``w`` in ``[-W, W)``, where ``W = time_warping_para``.

    Args:
        mel_spectrogram: 4-D spectrogram tensor/array. Axis 1 (length n) is
            treated as time and axis 2 (length v) as frequency; presumably
            shaped (batch, time, freq, channels) — confirm against callers.
        time_warping_para: time warp parameter W (default 80, the LibriSpeech
            setting).

    Returns:
        The warped spectrogram produced by ``sparse_image_warp``.
    """
    fbank_size = tf.shape(mel_spectrogram)
    n, v = fbank_size[1], fbank_size[2]  # n: time steps, v: frequency bins

    # Source: a random point along the time axis, at least W from either edge.
    pt = tf.random.uniform([], time_warping_para, n - time_warping_para, tf.int32)
    src_ctr_pt_freq = tf.range(v // 2)  # control points on the freq axis
    src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt  # all at time pt
    # sparse_image_warp uses (row, col) = (axis 1, axis 2) = (time, freq).
    src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
    src_ctr_pts = tf.cast(src_ctr_pts, dtype=tf.float32)

    # Destination: shift left or right by w in [-W, W).
    w = tf.random.uniform([], -time_warping_para, time_warping_para, tf.int32)
    dest_ctr_pt_freq = src_ctr_pt_freq
    dest_ctr_pt_time = src_ctr_pt_time + w
    dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=tf.float32)

    # Warp.
    source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)  # (1, v//2, 2)
    dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)  # (1, v//2, 2)
    warped_image, _ = sparse_image_warp(mel_spectrogram,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
def sparse_warp(mel_spectrogram, time_warping_para=80):
    """Time-warp a spectrogram with ``sparse_image_warp`` and return NumPy.

    This performs only the time-warping step of SpecAugment; no frequency or
    time masking is applied here.

    # Arguments:
        mel_spectrogram(numpy array): spectrogram to warp. Axis 1 (length n)
            is used as time and axis 2 (length v) as frequency; presumably
            shaped (batch, n, v, channels) — confirm against callers.
        time_warping_para(int): time warp parameter W. The anchor is drawn
            from [W, n - W) and the shift from [-W, W). Default 80
            (LibriSpeech).

    # Returns
        mel_spectrogram(numpy array): the warped spectrogram.
    """
    dims = tf.shape(mel_spectrogram)
    n, v = dims[1], dims[2]
    W = time_warping_para

    anchor = tf.random.uniform([], W, n - W, tf.int32)  # random point along time
    freq_idx = tf.range(v // 2)  # control points on freq-axis
    src_time = tf.ones_like(freq_idx) * anchor
    shift = tf.random.uniform([], -W, W, tf.int32)  # warp distance
    dst_time = src_time + shift

    def _as_points(time_idx):
        # (time, freq) float32 pairs with a leading batch axis: (1, v // 2, 2).
        pairs = tf.cast(tf.stack((time_idx, freq_idx), -1), dtype=tf.float32)
        return tf.expand_dims(pairs, 0)

    warped_image, _ = sparse_image_warp(
        tf.convert_to_tensor(mel_spectrogram),
        _as_points(src_time),
        _as_points(dst_time),
        regularization_weight=1e-5,
        num_boundary_points=1,
    )
    return warped_image.numpy()
def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use):
    """Move a single block in a small grid using warping.

    Builds a 1x7x7x3 black image with one white pixel at (row 3, col 3).
    A single control point on that pixel is displaced by (0, 1). The test
    checks three things:
    - the white pixel lands at (3, 4);
    - a background pixel pair nearby stays equal;
    - the returned flow is zero at the four corners.

    Args:
        order: interpolation order passed to sparse_image_warp.
        num_boundary_points: passed through to sparse_image_warp.
        type_to_use: dtype of the image and control points.
    """
    batch_size = 1
    image_height = 7
    image_width = 7
    channels = 3
    image = np.zeros([batch_size, image_height, image_width, channels])
    image[:, 3, 3, :] = 1.0
    input_image = tf.constant(image, dtype=type_to_use)

    # Place a control point at the one white pixel.
    control_point_locations = [[3.0, 3.0]]
    control_point_locations = tf.constant(np.float32(
        np.expand_dims(control_point_locations, 0)), dtype=type_to_use)
    # Shift it one pixel to the right.
    control_point_displacements = [[0.0, 1.0]]
    control_point_displacements = tf.constant(
        np.float32(np.expand_dims(control_point_displacements, 0)),
        dtype=type_to_use,
    )

    (warped_image, flow) = sparse_image_warp(
        input_image,
        control_point_locations,
        control_point_locations + control_point_displacements,
        interpolation_order=order,
        num_boundary_points=num_boundary_points,
    )

    warped_image, input_image, flow = self.evaluate(
        [warped_image, input_image, flow])

    # Check that it moved the pixel correctly: the white pixel from (3, 3)
    # must now be at (3, 4).
    self.assertAllClose(warped_image[0, 3, 4, :], input_image[0, 3, 3, :],
                        atol=1e-5, rtol=1e-5)
    # Background check: (4, 4) is zero in the input; (4, 5) in the output
    # must match it.
    self.assertAllClose(warped_image[0, 4, 5, :], input_image[0, 4, 4, :],
                        atol=1e-5, rtol=1e-5)

    # Test that there is no flow at the corners.
    for i in (0, image_height - 1):
        for j in (0, image_width - 1):
            self.assertAllClose(flow[0, i, j, :], np.zeros([2]),
                                atol=1e-5, rtol=1e-5)
def time_warp(data, w=80):
    """Pick a random point along the time axis between w and n_time_bins-w
    and warp by a distance between [0,w] towards the left or the right.

    Args:
        data: batch of spectrogram. Shape [n_examples, n_freq_bins, n_time_bins,
            n_channels] (NHWC)
        w: warp parameter (see above). The anchor time is drawn from
            [w, n_time_bins - w) and the shift from [-w, w).

    Returns:
        The warped batch (the first output of ``sparse_image_warp``).
    """
    # Index the dynamic shape instead of unpacking it: tuple-unpacking iterates
    # the tensor, which is not allowed inside tf.function / graph mode
    # (e.g. when used in tf.data.Dataset.map).
    shape = tf.shape(data)
    n_freq_bins, n_time_bins = shape[1], shape[2]

    # pick a random point along the time axis in [w,n_time_bins-w]
    t = tf.random.uniform(shape=(), minval=w, maxval=n_time_bins - w,
                          dtype=tf.int32)
    # pick a random translation vector in [-w,w] along the time axis
    tv = tf.cast(
        tf.random.uniform(shape=(), minval=-w, maxval=w, dtype=tf.int32),
        tf.float32)

    # set control points y-coordinates (bottom, middle, top frequency rows)
    ctl_pt_freqs = tf.convert_to_tensor([
        0.0,
        tf.cast(n_freq_bins, tf.float32) / 2.0,
        tf.cast(n_freq_bins - 1, tf.float32)
    ])
    # set source control point x-coordinates
    ctl_pt_times_src = tf.convert_to_tensor([t, t, t], dtype=tf.float32)
    # set destination control points
    ctl_pt_times_dst = ctl_pt_times_src + tv
    ctl_pt_src = tf.expand_dims(
        tf.stack([ctl_pt_freqs, ctl_pt_times_src], axis=-1), 0)
    ctl_pt_dst = tf.expand_dims(
        tf.stack([ctl_pt_freqs, ctl_pt_times_dst], axis=-1), 0)

    return sparse_image_warp(data, ctl_pt_src, ctl_pt_dst,
                             num_boundary_points=1)[0]
def time_warp(image):
    """Warp ``image`` by moving control points from a random start to an end.

    The start position is drawn from ``[MAX_SHIFT, W - MAX_SHIFT)``. The end
    position is the start plus a shift in ``[-MAX_SHIFT, MAX_SHIFT)``.
    ``get_points`` (defined elsewhere) turns each position into control
    points.

    Returns:
        The warped image with every size-1 dimension squeezed out. This
        includes the added batch axis and presumably a size-1 channel axis, if
        present.
    """
    start = TF.random.uniform([], minval=MAX_SHIFT, maxval=(W - MAX_SHIFT),
                              dtype=TF.int32)
    offset = TF.random.uniform([], minval=-MAX_SHIFT, maxval=MAX_SHIFT,
                               dtype=TF.int32)
    end = offset + start

    # Control points before and after the warp.
    before = get_points(start)
    after = get_points(end)

    # sparse_image_warp expects a leading batch dimension on every input.
    warped, _ = sparse_image_warp(TF.expand_dims(image, 0),
                                  TF.expand_dims(before, 0),
                                  TF.expand_dims(after, 0))
    return TF.squeeze(warped)
def sparse_warp(self, mel_spectrogram, time_warping_para=80):
    """Spec augmentation time-warping step.

    # Arguments:
        mel_spectrogram: spectrogram tensor/array. Axis 1 (length n) is
            treated as time and axis 2 (length v) as frequency; presumably
            shaped (batch, n, v, channels) — confirm against callers.
        time_warping_para(int): time warp parameter W. The anchor is drawn
            from [W, n - W) and the shift from [-W, W). Default 80
            (LibriSpeech).

    # Returns
        The warped spectrogram produced by ``sparse_image_warp``.
    """
    fbank_size = tf.shape(mel_spectrogram)
    n, v = fbank_size[1], fbank_size[2]

    # Image warping control point setting.
    # Source
    pt = tf.random.uniform([], time_warping_para, n - time_warping_para, tf.int32)  # random point along the time axis
    src_ctr_pt_freq = tf.range(v // 2)  # control points on freq-axis
    src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt  # control points on time-axis
    src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
    src_ctr_pts = tf.cast(src_ctr_pts, dtype=tf.float32)

    # Destination
    w = tf.random.uniform([], -time_warping_para, time_warping_para, tf.int32)  # distance
    dest_ctr_pt_freq = src_ctr_pt_freq
    # Keep destination times inside [0, n]. Element-wise clipping preserves
    # the v // 2 length that tf.stack needs (tf.repeat(..., v) did not). It
    # also works in graph mode, unlike a Python `if` on a tensor.
    dest_ctr_pt_time = tf.clip_by_value(src_ctr_pt_time + w, 0, n)
    dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=tf.float32)

    # warp
    source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)  # (1, v//2, 2)
    dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)  # (1, v//2, 2)

    warped_image, _ = sparse_image_warp(mel_spectrogram,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
def visualize_augment() -> None:
    """Plot a synthetic grid image before and after random landmark warping.

    Steps:
    1. Build a 64x64 blocky test image.
    2. Move ``num_landmarks`` random landmarks by up to ``max_warp_time`` /
       ``max_warp_freq`` pixels with ``sparse_image_warp``.
    3. Plot three panels:
       - the original image with the source landmarks;
       - the warped image with the destination landmarks that produced it;
       - the output of ``warp_spectrograms`` on the same grid.

    The figure is saved to ``plot_models/warp_grid.pdf``, creating the folder
    if needed, and then shown.
    """
    logg = logging.getLogger(f"c.{__name__}.visualize_augment")
    # logg.setLevel("INFO")
    logg.debug("Start visualize_augment")

    # Build a sample image to warp: grid_w x grid_h blocks of grid_stride
    # pixels, block (x, y) filled with the value (x + 1) * (y + 1).
    grid_stride = 8
    grid_w = 8
    grid_h = 8
    grid = np.zeros((grid_w * grid_stride, grid_h * grid_stride), dtype=np.float32)
    for x in range(grid_w):
        for y in range(grid_h):
            xs, xe = x * grid_stride, (x + 1) * grid_stride
            ys, ye = y * grid_stride, (y + 1) * grid_stride
            grid[xs:xe, ys:ye] = (x + 1) * (y + 1)

    ################################
    # Random shift
    ################################
    rng = np.random.default_rng()
    num_landmarks = 4
    max_warp_time = 2
    max_warp_freq = 2
    num_samples = 1
    spec_dim = grid.shape

    # expand the grid with batch and channel dimension: (1, H, W, 1)
    grid_b = np.expand_dims(grid, axis=0)
    grid_b = np.expand_dims(grid_b, axis=-1)

    # the shape of the landmark for one dimension
    land_shape = num_samples, num_landmarks
    # the source point has to be at least max_warp_* from the border
    bounds_time = (max_warp_time, spec_dim[0] - max_warp_time)
    bounds_freq = (max_warp_freq, spec_dim[1] - max_warp_freq)

    # generate (num_sample, num_landmarks) time/freq positions
    source_land_t = rng.uniform(*bounds_time, size=land_shape).astype(np.float32)
    source_land_f = rng.uniform(*bounds_freq, size=land_shape).astype(np.float32)
    source_lnd = np.dstack((source_land_t, source_land_f))
    logg.debug(f"land_t.shape: {source_land_t.shape}")
    logg.debug(f"source_lnd.shape: {source_lnd.shape}")

    # generate the deltas, how much to shift each point
    delta_t = rng.uniform(-max_warp_time, max_warp_time, size=land_shape)
    delta_f = rng.uniform(-max_warp_freq, max_warp_freq, size=land_shape)
    dest_land_t = source_land_t + delta_t
    dest_land_f = source_land_f + delta_f
    dest_lnd = np.dstack((dest_land_t, dest_land_f)).astype(np.float32)
    logg.debug(f"dest_lnd.shape: {dest_lnd.shape}")

    # Warp with exactly these landmarks, so the plotted markers match the image.
    data_warped, _ = sparse_image_warp(
        grid_b, source_lnd, dest_lnd, num_boundary_points=2
    )
    logg.debug(f"data_warped.shape: {data_warped.shape}")
    grid_warped = data_warped[0][:, :, 0].numpy()
    logg.debug(f"grid_warped.shape: {grid_warped.shape}")

    # warp_spectrograms draws its own random landmarks from rng.
    spec_warped_b = warp_spectrograms(
        grid_b, num_landmarks, max_warp_time, max_warp_freq, rng
    )
    logg.debug(f"grid_warped_b.shape: {spec_warped_b.shape}")
    spec_warped = spec_warped_b[0][:, :, 0].numpy()

    # plot all the results
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8, 12))
    im_args = {"origin": "lower", "cmap": "YlOrBr"}
    axes[0].imshow(grid.T, **im_args)
    axes[1].imshow(grid_warped.T, **im_args)
    axes[2].imshow(spec_warped.T, **im_args)

    pl_args = {"linestyle": "none", "color": "c", "markersize": 8}
    pl_args["marker"] = "d"
    axes[0].plot(*source_lnd.T, label="Source landmarks", **pl_args)
    axes[0].set_title("Original image")
    axes[1].plot(*dest_lnd.T, label="Dest landmarks", **pl_args)
    axes[1].set_title("Warped image")
    axes[2].set_title("warp_spectrograms output")
    fig.tight_layout()

    plot_folder = Path("plot_models")
    plot_folder.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_folder / "warp_grid.pdf")
    plt.show()