Example #1
0
    def test_unpad_1(self):
        """Check the spatial size of each unpadded second-order output.

        A 99x99 single-channel input is padded, run through the cascade,
        and every convolution is handed to its matching unpad callable.
        """
        from blusky.transforms.cascade_tree import CascadeTree

        image = Input(shape=(99, 99, 1))
        padded = self.pad_2d.pad(image)

        # First tree: the convolutions themselves.
        conv_tree = CascadeTree(padded, order=2)
        conv_tree.generate(self.wavelets, self.cascade._convolve)
        convs = conv_tree.get_convolutions()

        # Second tree, generated the same way: one unpad callable per node.
        unpad_tree = CascadeTree(padded, order=2)
        unpad_tree.generate(self.wavelets, self.pad_2d.unpad)
        unpadders = unpad_tree.get_convolutions()

        unpadded_convs = [
            strip(conv) for conv, strip in zip(convs, unpadders)
        ]

        # The reason the shape is changing is because its being decimated.
        expected_sizes = (99, 50, 25, 50, 25, 25)
        for idx, size in enumerate(expected_sizes):
            output = unpadded_convs[idx]
            self.assertTrue(output.shape[1] == size)
            self.assertTrue(output.shape[2] == size)
Example #2
0
    def test_cascade_1d_results(self):
        """Compare three outputs of the 1-d cascade with stored references.

        The model is built without decimation, padded with Pad1D and
        unpadded again before prediction on ``self.ts``.
        """
        # vanilla filter bank
        wavelets = [
            vanilla_morlet_1d(self.sample_rate, self.J, j=scale)
            for scale in range(0, self.J)
        ]

        deci = NoDecimation()
        inp = Input(shape=(self.N, 1))

        # pad
        pad_1d = Pad1D(wavelets, decimation=deci)
        padded = pad_1d.pad(inp)

        # Run the cascade on the padded input.
        conv_tree = CascadeTree(padded, order=self.order)
        cascade = Cascade1D(decimation=deci)
        convs = cascade.transform(conv_tree, wavelets=wavelets)

        # Create layers to remove padding
        unpad_tree = CascadeTree(padded, order=self.order)
        unpad_tree.generate(wavelets, pad_1d._unpad_same)
        unpad = unpad_tree.get_convolutions()

        # Remove the padding
        unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

        model = Model(inputs=inp, outputs=unpadded_convs)
        result = model.predict(np.expand_dims(self.ts, axis=0))

        # Pull out every output of interest before comparing anything.
        output_indices = (0, 5, -10)
        cnn_results = [np.squeeze(result[idx]) for idx in output_indices]
        references = (self.conv1, self.conv2, self.conv3)

        for reference, cnn_result in zip(references, cnn_results):
            np.testing.assert_allclose(
                reference,
                cnn_result,
                atol=1e-3,
                err_msg="first order does not match with cnn result.",
            )
Example #3
0
    def test_unpad_2(self):
        """Compare the Pad2D pad/unpad path with manual reflect padding.

        The first model pads a 99x99 input with ``self.pad_2d`` and unpads
        every convolution.  The second model receives an image that was
        reflect-padded by hand to 131x131 (16 samples on each side); its
        first output is cropped back by 16 samples on each side.  The
        first outputs of both models must agree.
        """
        from blusky.transforms.cascade_tree import CascadeTree

        # --- Pad/Unpad --- #
        inp = Input(shape=(99, 99, 1))
        padded = self.pad_2d.pad(inp)

        cascade_tree = CascadeTree(padded, order=2)
        cascade_tree.generate(self.wavelets, self.cascade._convolve)
        convs = cascade_tree.get_convolutions()

        cascade_tree = CascadeTree(padded, order=2)
        cascade_tree.generate(self.wavelets, self.pad_2d.unpad)
        unpadded = cascade_tree.get_convolutions()

        unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpadded)]

        model_1 = Model(inputs=inp, outputs=unpadded_convs)

        result_1 = model_1.predict(self.imgs[:1])

        # --- Manually --- #
        inp = Input(shape=(131, 131, 1))

        cascade_tree = CascadeTree(inp, order=2)
        cascade_tree.generate(self.wavelets, self.cascade._convolve)
        convs = cascade_tree.get_convolutions()

        model_2 = Model(inputs=inp, outputs=convs)

        padded = np.pad(
            np.array(self.imgs[:1]),
            ((0, 0), (16, 16), (16, 16), (0, 0)),
            "reflect",
        )

        result_2 = model_2.predict(padded)

        # Crop the manual padding off the first output only; that is the
        # one compared below.
        manual_first = result_2[0][:, 16:-16, 16:-16, :]

        # The comparison result used to be discarded, so the test could
        # never fail here; assert it explicitly.
        self.assertTrue(np.allclose(result_1[0], manual_first))
Example #4
0
    def test_default_decimation(self):
        """ test the default decimation method.

            At the first layer we decimate according to scale.
            At subsequent layers, we decimate according to the difference
            in order with the current and previous layers.
        """
        import numpy as np

        class test_wav(HasStrictTraits):
            scale = Int(1)

        cascade_tree = CascadeTree(Input(shape=(64, 64, 1)), order=4)
        cascade_tree.generate(
            [
                test_wav(scale=1),
                test_wav(scale=2),
                test_wav(scale=3),
                test_wav(scale=4),
            ],
            lambda x, y, z: x,
        )

        # go down a path in the tree
        root_node = cascade_tree.root_node
        first_layer = root_node.children[0]
        second_layer = first_layer.children[0]
        third_layer = second_layer.children[0]

        # Use deterministic values: np.empty returns uninitialised memory
        # (possibly NaNs), which would make an equality check meaningless.
        deci = DefaultDecimation(oversampling=0)
        wav = np.arange(64 * 64, dtype=float).reshape(64, 64)
        wavp = deci.decimate_wavelet(wav, 2)
        # The comparison result used to be discarded; assert it.
        self.assertTrue(np.array_equal(wav[::2, ::2], wavp))

        # if I would decimate wavelet/conv: expected (w, c) for the nodes
        # root -> first -> second -> third, per oversampling setting.
        path = (root_node, first_layer, second_layer, third_layer)
        expected_scales = (
            (0, ((1, 1), (1, 2), (2, 2), (4, 2))),
            (1, ((1, 1), (1, 1), (1, 2), (2, 2))),
            (2, ((1, 1), (1, 1), (1, 1), (1, 2))),
        )
        for oversampling, expected in expected_scales:
            deci = DefaultDecimation(oversampling=oversampling)
            for node, (exp_w, exp_c) in zip(path, expected):
                w, c = deci.resolve_scales(node)
                self.assertTrue(w == exp_w and c == exp_c)

        deci = DefaultDecimation(oversampling=0)

        root_node = cascade_tree.root_node
        first_layer = root_node.children[0].children[0].children[0]

        # convolve with stride of 2 (scale 1 wavelet)
        # Then convolve again, j2 - j1 is 1 so add a stride of 2, decimate
        # wav by 2
        # to adjust bandwidth to decimated signal
        # Then convolve again  j3 - j2 is 1 so add a stride of 2, decimate
        # wave by 4,
        # we've decimated twice by 2 previously
        w, c = deci.resolve_scales(first_layer)
        self.assertTrue(w == 4 and c == 2)

        deci = DefaultDecimation(oversampling=1)
        # convolve with stride of 1 (scale 1 wavelet (oversampling))
        # Then convolve again, j2 - j1 is 1 so add a stride of 2
        # to adjust bandwidth to decimated signal
        # Then convolve again  j3 - j2 is 1 so add a stride of 2, decimate
        # wave by 2,
        # we've decimated once by 2 previously
        w, c = deci.resolve_scales(first_layer)
        self.assertTrue(w == 2 and c == 2)

        deci = DefaultDecimation(oversampling=1)
        first_layer = root_node.children[1].children[0].children[0]

        # convolve with stride of 2 (scale 2 wavelet (oversampling))
        # Then convolve again, j2 - j1 is 1 so add a stride of 2
        # to adjust bandwidth to decimated signal
        # Then convolve again  j3 - j2 is 1 so add a stride of 2, decimate
        # wave by 2,
        # we've decimated twice by 2 previously
        w, c = deci.resolve_scales(first_layer)
        self.assertTrue(w == 4 and c == 2)

        deci = DefaultDecimation(oversampling=0)
        first_layer = root_node.children[2].children[0]
        # convolve with stride of 8 (scale 3 wavelet (oversampling))
        # Then convolve again, j2 - j1 is 1 so add a stride of 2
        w, c = deci.resolve_scales(first_layer)
        self.assertTrue(w == 8 and c == 2)
Example #5
0
    def test_apply_father_wavelet_results(self):
        """Compare father-wavelet outputs of the CNN with manual results.

        Three coefficients are read from the model prediction and compared
        with ``self.manual_result1..3``.  Both sets are compared relative
        to their own first entry (entries 1 and 2 are divided by entry 0).
        """
        # Complete the scattering transform with the father wavelet
        father_wavelet = vanilla_gabor_2d(self.sample_rate, j=self.J)
        wavelets = [
            vanilla_morlet_2d(self.sample_rate, j=scale)
            for scale in range(0, self.J)
        ]

        apply_conv = ApplyFatherWavlet2D(
            J=self.J,
            overlap_log_2=self.overlap_log_2,
            img_size=self.img.shape,
            sample_rate=self.sample_rate,
            wavelet=father_wavelet,
        )

        deci = NoDecimation()  # DefaultDecimation(oversampling=oversampling)

        # input
        inp = Input(shape=self.img.shape)

        # valid padding
        cascade2d = Cascade2D("none", 0, decimation=deci, angles=self.angles)

        # Pad the input
        pad_2d = Pad2D(wavelets, decimation=deci)
        padded = pad_2d.pad(inp)

        # Apply the cascade (NoDecimation is in use here)
        conv_tree = CascadeTree(padded, order=self.order)
        conv_tree.generate(wavelets, cascade2d._convolve)
        convs = conv_tree.get_convolutions()

        # Create layers to remove padding
        unpad_tree = CascadeTree(padded, order=self.order)
        unpad_tree.generate(wavelets, pad_2d._unpad_same)
        unpad = unpad_tree.get_convolutions()

        # Remove the padding
        unpadded_convs = [strip(conv) for conv, strip in zip(convs, unpad)]

        sca_transf = apply_conv.convolve(unpadded_convs)

        model = Model(inputs=inp, outputs=sca_transf)
        result = model.predict(np.expand_dims(self.img, axis=0))

        # Coefficients picked from outputs 0, 3 and 6 (last axis 0, 1, 5).
        cnn_result = np.array(
            [
                result[0][0, 0, 0, 0],
                result[3][0, 0, 0, 1],
                result[6][0, 0, 0, 5],
            ]
        )

        # use all close to assert relative error:
        manual = np.array(
            [self.manual_result1, self.manual_result2, self.manual_result3]
        )
        manual[1:] /= manual[0]
        cnn_result[1:] /= cnn_result[0]

        np.testing.assert_allclose(
            manual,
            cnn_result,
            atol=1e-3,
            err_msg="first order does not match with cnn result.",
        )
Example #6
0
def build_model_1d(N, J, order,
                   oversampling=2,
                   conv_padding="valid",
                   concatenate=False,
                   apply_father_wavelet=True):
    """ Assemble a Keras model for the 1-d wavelet transform.

    N - int, length of the 1-d series; the model input has shape (N, 1).
    J - int, the scale of the transform; one Morlet wavelet is created
            for every j in range(J).
    order - int, the order of the transform (passed to CascadeTree).
    oversampling - int, passed to DefaultDecimation; 2**oversampling
            delays decimation of the series.
    conv_padding - "valid" or "same", forwarded to Pad1D and Cascade1D
            (see Keras' Conv1D); "valid" should be faster.
    concatenate - True/False, whether to merge the outputs with a
            Concatenate layer.  Without the father wavelet every output
            is first passed through apply_upsampling(output, N).
    apply_father_wavelet - True/False, whether to finish the transform
            by convolving the outputs with the father wavelet.

    Returns the Keras Model mapping the input to the selected outputs.
    """
    sample_rate = 1.0

    inp = Input(shape=(N, 1))

    wavelets = [vanilla_morlet_1d(sample_rate, J, j=scale)
                for scale in range(0, J)]
    calibrate_wavelets_1d(wavelets)

    father_wavelet = vanilla_gabor_1d(sample_rate, J)

    deci = DefaultDecimation(oversampling=oversampling)

    # pad
    pad_1d = Pad1D(wavelets, decimation=deci,
                   conv_padding=conv_padding)
    padded = pad_1d.pad(inp)

    # Run the cascade on the padded input.
    conv_tree = CascadeTree(padded, order=order)
    cascade = Cascade1D(decimation=deci, _padding=conv_padding)
    convs = cascade.transform(conv_tree, wavelets=wavelets)

    # Create layers to remove padding
    unpad_tree = CascadeTree(padded, order=order)
    unpad_tree.generate(wavelets, pad_1d.unpad)
    unpad = unpad_tree.get_convolutions()

    # Remove the padding
    outputs = [strip(conv) for conv, strip in zip(convs, unpad)]

    if apply_father_wavelet:
        # implement scattering transform.
        appl = ApplyFatherWavlet1D(wavelet=father_wavelet,
                                   J=J,
                                   img_size=(N,),
                                   sample_rate=sample_rate)
        outputs = appl.convolve(outputs)
    elif concatenate:
        # Each output goes through apply_upsampling before the merge.
        outputs = [apply_upsampling(conv, N) for conv in outputs]

    if concatenate:
        outputs = Concatenate()(outputs)

    return Model(inputs=inp, outputs=outputs)
Example #7
0
def vanilla_scattering_transform(J,
                                 img_size,
                                 sample_rate,
                                 overlap_log_2=0,
                                 order=2,
                                 oversampling=1,
                                 num_angles=8,
                                 do_father=True):
    """Build a 2-d scattering-transform model and a matching visualiser.

    J - int, number of Morlet wavelets (j in range(J)); the father
            wavelet is a Gabor at j=J.
    img_size - shape handed to the Keras Input and ApplyFatherWavlet2D.
    sample_rate - sample rate handed to the wavelet constructors.
    overlap_log_2 - forwarded to ApplyFatherWavlet2D.
    order - int, order of the cascade tree.
    oversampling - int, forwarded to DefaultDecimation.
    num_angles - int, number of wavelet orientations.
    do_father - True/False, whether the unpadded convolutions are passed
            through the father wavelet before becoming model outputs.

    Returns a tuple (model, viz): the Keras Model and a Visualize2D
    instance that has recursed over a freshly generated cascade tree.
    """
    # to reproduce scatnet.m and kymatio definitions (see NOTICE.txt)
    centre = int(num_angles - num_angles / 2 - 1)
    angles = tuple(
        90.0 - np.rad2deg((centre - theta) * np.pi / num_angles)
        for theta in range(num_angles)
    )

    # vanilla filter bank
    wavelets = [vanilla_morlet_2d(sample_rate, j=scale)
                for scale in range(0, J)]
    father_wavelet = vanilla_gabor_2d(sample_rate, j=J)

    # method of decimation
    deci = DefaultDecimation(oversampling=oversampling)

    # input
    inp = Input(shape=img_size)

    # valid padding
    cascade2d = Cascade2D("none", 0, decimation=deci, angles=angles)

    # Pad the input
    pad_2d = Pad2D(wavelets, decimation=deci)
    padded = pad_2d.pad(inp)

    # Apply cascade with successive decimation
    conv_tree = CascadeTree(padded, order=order)
    conv_tree.generate(wavelets, cascade2d._convolve)
    convs = conv_tree.get_convolutions()

    # Create layers to remove padding
    unpad_tree = CascadeTree(padded, order=order)
    unpad_tree.generate(wavelets, pad_2d._unpad_same)
    unpad = unpad_tree.get_convolutions()

    # Remove the padding
    outputs = [strip(conv) for conv, strip in zip(convs, unpad)]

    # Complete the scattering transform with the father wavelet
    if do_father:
        apply_conv = ApplyFatherWavlet2D(
            J=J - 1,
            overlap_log_2=overlap_log_2,
            img_size=img_size,
            sample_rate=sample_rate,
            wavelet=father_wavelet,
        )
        outputs = apply_conv.convolve(outputs)

    model = Model(inputs=inp, outputs=outputs)

    # generate visuals too, it's a factory
    viz_tree = CascadeTree(padded, order=order)
    viz_tree.generate(wavelets, cascade2d._convolve)

    root_element = PlotElement(name=viz_tree.root_node.name)
    root_element.radius_range = (0.02, 1)
    root_element.angle_range = (0, 180)

    viz = Visualize2D(angles=angles)
    viz.recurse(viz_tree.root_node, root_element, max_order=order)

    return model, viz