def test_unpad_1(self):
    """Check the spatial size of every unpadded output of an order-2 cascade.

    A 99x99 input is padded with ``self.pad_2d``, convolved through a
    ``CascadeTree`` built with ``self.cascade._convolve``, and then each
    output is passed through the matching ``self.pad_2d.unpad`` layer.
    """
    from blusky.transforms.cascade_tree import CascadeTree

    inp = Input(shape=(99, 99, 1))
    padded = self.pad_2d.pad(inp)

    # Convolution layers for every node of the tree.
    cascade_tree = CascadeTree(padded, order=2)
    cascade_tree.generate(self.wavelets, self.cascade._convolve)
    convs = cascade_tree.get_convolutions()

    # Unpad layers, generated over a second tree so they line up one-to-one
    # with the convolution outputs above.
    cascade_tree = CascadeTree(padded, order=2)
    cascade_tree.generate(self.wavelets, self.pad_2d.unpad)
    unpadded = cascade_tree.get_convolutions()

    unpadded_convs = [unpad(conv) for conv, unpad in zip(convs, unpadded)]

    # The reason the shape is changing is because it's being decimated.
    # assertEqual (unlike assertTrue(a == b)) reports the actual size when a
    # check fails; subTest identifies which output failed.
    expected_sizes = [99, 50, 25, 50, 25, 25]
    for index, size in enumerate(expected_sizes):
        with self.subTest(index=index):
            shape = unpadded_convs[index].shape
            self.assertEqual(shape[1], size)
            self.assertEqual(shape[2], size)
def test_cascade_1d_results(self):
    """Compare selected outputs of the 1-d cascade model with reference results.

    The model pads the series, applies ``Cascade1D`` without decimation and
    removes the padding again. Outputs 0, 5 and -10 are checked against
    ``self.conv1``, ``self.conv2`` and ``self.conv3`` respectively.
    """
    # vanilla filter bank
    wavelets = [
        vanilla_morlet_1d(self.sample_rate, self.J, j=i) for i in range(0, self.J)
    ]
    deci = NoDecimation()
    inp = Input(shape=(self.N, 1))

    # pad
    pad_1d = Pad1D(wavelets, decimation=deci)
    padded = pad_1d.pad(inp)

    # convolve
    cascade_tree = CascadeTree(padded, order=self.order)
    cascade = Cascade1D(decimation=deci)
    convs = cascade.transform(cascade_tree, wavelets=wavelets)

    # Create layers to remove padding
    cascade_tree = CascadeTree(padded, order=self.order)
    cascade_tree.generate(wavelets, pad_1d._unpad_same)
    unpad = cascade_tree.get_convolutions()

    # Remove the padding
    unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

    model = Model(inputs=inp, outputs=unpadded_convs)
    result = model.predict(np.expand_dims(self.ts, axis=0))

    # (reference, model-output index) pairs. Each failure message names the
    # output index that was compared, so a failure points at the right check.
    checks = ((self.conv1, 0), (self.conv2, 5), (self.conv3, -10))
    for expected, index in checks:
        np.testing.assert_allclose(
            expected,
            np.squeeze(result[index]),
            atol=1e-3,
            err_msg="cnn output {} does not match the reference result.".format(index),
        )
def test_unpad_2(self):
    """Check that pad -> cascade -> unpad matches manual reflect padding.

    Model 1 pads the 99x99 input with ``self.pad_2d``, applies the cascade
    and removes the padding with ``self.pad_2d.unpad``. Model 2 applies the
    same cascade to a 131x131 image reflect-padded by 16 pixels per side
    with numpy, and the 16-pixel border is cropped from its first output.
    The first outputs of the two models must agree.
    """
    from blusky.transforms.cascade_tree import CascadeTree

    # --- Pad/Unpad --- #
    inp = Input(shape=(99, 99, 1))
    padded = self.pad_2d.pad(inp)
    cascade_tree = CascadeTree(padded, order=2)
    cascade_tree.generate(self.wavelets, self.cascade._convolve)
    convs = cascade_tree.get_convolutions()

    cascade_tree = CascadeTree(padded, order=2)
    cascade_tree.generate(self.wavelets, self.pad_2d.unpad)
    unpadded = cascade_tree.get_convolutions()

    unpadded_convs = [unpad(conv) for conv, unpad in zip(convs, unpadded)]
    model_1 = Model(inputs=inp, outputs=unpadded_convs)
    result_1 = model_1.predict(self.imgs[:1])

    # --- Manually --- #
    # 131 == 99 + 2 * 16, matching the reflect padding applied below.
    inp = Input(shape=(131, 131, 1))
    cascade_tree = CascadeTree(inp, order=2)
    cascade_tree.generate(self.wavelets, self.cascade._convolve)
    convs = cascade_tree.get_convolutions()
    model_2 = Model(inputs=inp, outputs=convs)

    padded = np.pad(
        np.array(self.imgs[:1]),
        ((0, 0), (16, 16), (16, 16), (0, 0)),
        "reflect",
    )
    result_2 = model_2.predict(padded)
    result_2[0] = result_2[0][:, 16:-16, 16:-16, :]

    # np.allclose only returns a bool; the comparison must be asserted or a
    # mismatch goes unnoticed. Tolerances match np.allclose's defaults.
    np.testing.assert_allclose(result_1[0], result_2[0], rtol=1e-5, atol=1e-8)
def test_default_decimation(self):
    """
    Test the default decimation method. At the first layer we decimate
    according to scale. At subsequent layers, we decimate according to the
    difference in scale between the current and previous layers.
    """
    import numpy as np

    class FakeWavelet(HasStrictTraits):
        scale = Int(1)

    cascade_tree = CascadeTree(Input(shape=(64, 64, 1)), order=4)
    cascade_tree.generate(
        [
            FakeWavelet(scale=1),
            FakeWavelet(scale=2),
            FakeWavelet(scale=3),
            FakeWavelet(scale=4),
        ],
        lambda x, y, z: x,
    )

    # go down a path in the tree
    root_node = cascade_tree.root_node
    first_layer = root_node.children[0]
    second_layer = first_layer.children[0]
    third_layer = second_layer.children[0]

    # Decimating a wavelet by 2 should keep every other sample on both axes.
    # A deterministic array is used because np.empty may contain NaNs, which
    # never compare equal.
    deci = DefaultDecimation(oversampling=0)
    wav = np.arange(64 * 64, dtype=float).reshape(64, 64)
    wavp = deci.decimate_wavelet(wav, 2)
    self.assertTrue(np.array_equal(wav[::2, ::2], wavp))

    # Expected (wavelet, conv) scales returned by resolve_scales along the
    # path root -> first -> second -> third layer, per oversampling level.
    path = [root_node, first_layer, second_layer, third_layer]
    expected = {
        0: [(1, 1), (1, 2), (2, 2), (4, 2)],
        1: [(1, 1), (1, 1), (1, 2), (2, 2)],
        2: [(1, 1), (1, 1), (1, 1), (1, 2)],
    }
    for oversampling, scales in expected.items():
        deci = DefaultDecimation(oversampling=oversampling)
        for depth, (node, scale) in enumerate(zip(path, scales)):
            with self.subTest(oversampling=oversampling, depth=depth):
                w, c = deci.resolve_scales(node)
                self.assertEqual((w, c), scale)

    # Deeper, explained cases.
    node = root_node.children[0].children[0].children[0]

    deci = DefaultDecimation(oversampling=0)
    # convolve with stride of 2 (scale 1 wavelet)
    # Then convolve again, j2 - j1 is 1 so add a stride of 2, decimate
    # wav by 2 to adjust bandwidth to decimated signal
    # Then convolve again j3 - j2 is 1 so add a stride of 2, decimate
    # wave by 4, we've decimated twice by 2 previously
    w, c = deci.resolve_scales(node)
    self.assertEqual((w, c), (4, 2))

    deci = DefaultDecimation(oversampling=1)
    # convolve with stride of 1 (scale 1 wavelet (oversampling))
    # Then convolve again, j2 - j1 is 1 so add a stride of 2
    # to adjust bandwidth to decimated signal
    # Then convolve again j3 - j2 is 1 so add a stride of 2, decimate
    # wave by 2, we've decimated once by 2 previously
    w, c = deci.resolve_scales(node)
    self.assertEqual((w, c), (2, 2))

    # still oversampling=1
    node = root_node.children[1].children[0].children[0]
    # convolve with stride of 2 (scale 2 wavelet (oversampling))
    # Then convolve again, j2 - j1 is 1 so add a stride of 2
    # to adjust bandwidth to decimated signal
    # Then convolve again j3 - j2 is 1 so add a stride of 2, decimate
    # wave by 2, we've decimated twice by 2 previously
    w, c = deci.resolve_scales(node)
    self.assertEqual((w, c), (4, 2))

    deci = DefaultDecimation(oversampling=0)
    node = root_node.children[2].children[0]
    # convolve with stride of 8 (scale 3 wavelet)
    # Then convolve again, j2 - j1 is 1 so add a stride of 2
    w, c = deci.resolve_scales(node)
    self.assertEqual((w, c), (8, 2))
def test_apply_father_wavelet_results(self):
    """Compare normalised 2-d scattering coefficients with manual results.

    Builds pad -> cascade -> unpad -> father-wavelet model, picks one
    coefficient from outputs 0, 3 and 6, divides the last two by the first,
    and compares against the equally normalised manual results.
    """
    # Complete the scattering transform with the father wavelet
    father_wavelet = vanilla_gabor_2d(self.sample_rate, j=self.J)
    wavelets = [
        vanilla_morlet_2d(self.sample_rate, j=i) for i in range(0, self.J)
    ]

    apply_conv = ApplyFatherWavlet2D(
        J=self.J,
        overlap_log_2=self.overlap_log_2,
        img_size=self.img.shape,
        sample_rate=self.sample_rate,
        wavelet=father_wavelet,
    )

    deci = NoDecimation()  # DefaultDecimation(oversampling=oversampling)

    # input
    inp = Input(shape=self.img.shape)

    # valid padding
    cascade2d = Cascade2D("none", 0, decimation=deci, angles=self.angles)

    # Pad the input
    pad_2d = Pad2D(wavelets, decimation=deci)
    padded = pad_2d.pad(inp)

    # Apply the cascade with the decimation strategy chosen above
    cascade_tree = CascadeTree(padded, order=self.order)
    cascade_tree.generate(wavelets, cascade2d._convolve)
    convs = cascade_tree.get_convolutions()

    # Create layers to remove padding
    cascade_tree = CascadeTree(padded, order=self.order)
    cascade_tree.generate(wavelets, pad_2d._unpad_same)
    unpad = cascade_tree.get_convolutions()

    # Remove the padding
    unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

    sca_transf = apply_conv.convolve(unpadded_convs)
    model = Model(inputs=inp, outputs=sca_transf)
    result = model.predict(np.expand_dims(self.img, axis=0))

    # One coefficient from each of outputs 0, 3 and 6, taken from channel
    # indices 0, 1 and 5 respectively (last axis).
    cnn_result1 = result[0][0, 0, 0, 0]
    cnn_result2 = result[3][0, 0, 0, 1]
    cnn_result3 = result[6][0, 0, 0, 5]

    # Compare relative magnitudes: normalise by the first coefficient.
    # dtype=float keeps the in-place division valid for integer inputs.
    manual = np.array(
        [self.manual_result1, self.manual_result2, self.manual_result3],
        dtype=float,
    )
    manual[1:] /= manual[0]
    cnn_result = np.array([cnn_result1, cnn_result2, cnn_result3], dtype=float)
    cnn_result[1:] /= cnn_result[0]

    np.testing.assert_allclose(
        manual,
        cnn_result,
        atol=1e-3,
        err_msg="normalised scattering coefficients do not match the manual results.",
    )
def build_model_1d(N, J, order, oversampling=2, conv_padding="valid",
                   concatenate=False, apply_father_wavelet=True):
    """
    Build a Keras model computing the 1-d wavelet (scattering) transform.

    N - int, length of the 1-d input series; the model input has shape (N, 1).
    J - int, the scale of the transform; Morlet wavelets j = 0 .. J-1 are used.
    order - int, the order of the transform (depth of the cascade tree).
    oversampling - int, 2**oversampling delays decimation of the series.
    conv_padding - "valid" or "same", supplied to the padding and cascade
        layers (see Keras' Conv1D). "valid" should be faster.
    concatenate - True/False, whether or not to concatenate the outputs into
        a single tensor. Without the father wavelet, each output is first
        passed through apply_upsampling(output, N) before concatenation.
    apply_father_wavelet - True/False, the final step in the transform is to
        convolve each output with the father (Gabor) wavelet.

    Returns the keras Model mapping the (N, 1) input to the transform outputs.
    """
    sample_rate = 1.0
    inp = Input(shape=(N, 1))

    wavelets = [vanilla_morlet_1d(sample_rate, J, j=i) for i in range(0, J)]
    # NOTE(review): presumably adjusts the filter bank in place; verify.
    calibrate_wavelets_1d(wavelets)
    father_wavelet = vanilla_gabor_1d(sample_rate, J)

    deci = DefaultDecimation(oversampling=oversampling)

    # pad
    pad_1d = Pad1D(wavelets, decimation=deci, conv_padding=conv_padding)
    padded = pad_1d.pad(inp)

    cascade_tree = CascadeTree(padded, order=order)
    cascade = Cascade1D(decimation=deci, _padding=conv_padding)
    convs = cascade.transform(cascade_tree, wavelets=wavelets)

    # Create layers to remove padding (a second tree, aligned with convs)
    cascade_tree = CascadeTree(padded, order=order)
    cascade_tree.generate(wavelets, pad_1d.unpad)
    unpad = cascade_tree.get_convolutions()

    # Remove the padding
    unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

    if apply_father_wavelet:
        # implement scattering transform.
        appl = ApplyFatherWavlet1D(wavelet=father_wavelet,
                                   J=J,
                                   img_size=(N,),
                                   sample_rate=sample_rate)
        sca_transf = appl.convolve(unpadded_convs)
        if concatenate:
            sca_transf = Concatenate()(sca_transf)
        model = Model(inputs=inp, outputs=sca_transf)
    else:
        if concatenate:
            unpadded_convs = [apply_upsampling(i, N) for i in unpadded_convs]
            unpadded_convs = Concatenate()(unpadded_convs)
        model = Model(inputs=inp, outputs=unpadded_convs)

    return model
def vanilla_scattering_transform(J, img_size, sample_rate, overlap_log_2=0,
                                 order=2, oversampling=1, num_angles=8,
                                 do_father=True):
    """Build a 2-d scattering-transform Keras model plus a matching visualiser.

    The input is padded, passed through a Morlet cascade of the given order
    with ``DefaultDecimation(oversampling)``, unpadded and, when
    ``do_father`` is true, convolved with the father (Gabor) wavelet.

    Returns a ``(model, viz)`` tuple, where ``viz`` is a ``Visualize2D``
    populated from a separately generated cascade tree.
    """
    # to reproduce scatnet.m and kymatio definitions (see NOTICE.txt)
    offset = int(num_angles - num_angles / 2 - 1)
    angles = tuple(
        90.0 - np.rad2deg((offset - theta) * np.pi / num_angles)
        for theta in range(num_angles)
    )

    # vanilla filter bank
    wavelets = [vanilla_morlet_2d(sample_rate, j=j) for j in range(J)]
    father_wavelet = vanilla_gabor_2d(sample_rate, j=J)

    # method of decimation
    deci = DefaultDecimation(oversampling=oversampling)

    inp = Input(shape=img_size)

    # valid padding
    cascade2d = Cascade2D("none", 0, decimation=deci, angles=angles)

    pad_2d = Pad2D(wavelets, decimation=deci)
    padded = pad_2d.pad(inp)

    def _tree_layers(layer_factory):
        # Generate one layer per tree node with layer_factory; return them.
        tree = CascadeTree(padded, order=order)
        tree.generate(wavelets, layer_factory)
        return tree.get_convolutions()

    # Cascade with successive decimation, then the aligned unpad layers.
    convs = _tree_layers(cascade2d._convolve)
    unpad = _tree_layers(pad_2d._unpad_same)
    unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

    if do_father:
        apply_conv = ApplyFatherWavlet2D(
            J=J - 1,
            overlap_log_2=overlap_log_2,
            img_size=img_size,
            sample_rate=sample_rate,
            wavelet=father_wavelet,
        )
        outputs = apply_conv.convolve(unpadded_convs)
    else:
        outputs = unpadded_convs
    model = Model(inputs=inp, outputs=outputs)

    # generate visuals too, it's a factory
    viz_tree = CascadeTree(padded, order=order)
    viz_tree.generate(wavelets, cascade2d._convolve)

    root_element = PlotElement(name=viz_tree.root_node.name)
    root_element.radius_range = (0.02, 1)
    root_element.angle_range = (0, 180)

    viz = Visualize2D(angles=angles)
    viz.recurse(viz_tree.root_node, root_element, max_order=order)

    return model, viz