Example #1
0
 def __init__(self, J, Q, audio_length):
     """Build a ``Scattering1D`` front end for signals of ``audio_length`` samples.

     Args:
         J: scattering parameter, forwarded to ``Scattering1D`` and
             ``Scattering1D.compute_meta_scattering``.
         Q: scattering parameter, forwarded to the same two calls.
         audio_length: signal length; stored as ``self.T`` and passed to
             ``Scattering1D`` as its second positional argument.

     Sets ``meta`` (scattering metadata), ``order{0,1,2}_indices``
     (``meta['order'] == k`` comparisons), ``scattering`` (the transform,
     moved with ``.cuda()``) and ``output_size``.
     """
     super(Scatter, self).__init__()

     self.J, self.Q, self.T = J, Q, audio_length

     # One comparison per scattering order; presumably boolean masks over the
     # coefficient axis (meta['order'] assumed to be an array -- confirm).
     self.meta = Scattering1D.compute_meta_scattering(self.J, self.Q)
     coefficient_order = self.meta['order']
     self.order0_indices = (coefficient_order == 0)
     self.order1_indices = (coefficient_order == 1)
     self.order2_indices = (coefficient_order == 2)

     # Construct the transform, then move it to the GPU.
     transform = Scattering1D(self.J, self.T, self.Q)
     self.scattering = transform.cuda()
     self.output_size = self.scattering.output_size()
Example #2
0
def normalised_frquency_vector(x,
                               J,
                               Q,
                               epsilon_order_1=1 * 10**-6,
                               epsilon_order_2=1 * 10**-6):
    """Return axis-1 averages of the normalised order-1 and order-2 scattering.

    The signal is peak-normalised and passed through a ``Scattering1D``
    transform. Its order-1 / order-2 coefficients are then normalised with
    ``normalise_order1`` / ``normalise_order2``, scaled with ``scale_value``
    and averaged over axis 1.

    Args:
        x: numpy array holding the signal; it is flattened to a single row.
            The caller's array is not modified.
        J: scattering parameter forwarded to ``Scattering1D``,
            ``compute_meta_scattering`` and ``normalise_order2``.
        Q: scattering parameter forwarded to the same calls.
        epsilon_order_1: epsilon passed to ``normalise_order1``.
        epsilon_order_2: epsilon passed to ``normalise_order2``.

    Returns:
        Tuple ``(order1_profile, order2_profile)`` of numpy arrays, each the
        axis-1 mean of the scaled, normalised coefficients (presumably one
        value per coefficient, averaged over time -- confirm against
        ``normalise_order1`` / ``normalise_order2``).
    """
    x = torch.from_numpy(x).float()
    # Divide out of place: ``torch.from_numpy`` shares memory with the
    # caller's array and ``.float()`` returns the same tensor for float32
    # input, so an in-place ``/=`` would rescale the caller's data.
    # An all-zero signal is left unchanged instead of turning into NaN (0/0).
    peak = x.abs().max()
    if peak > 0:
        x = x / peak
    # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
    x = x.reshape(1, -1)
    T = x.shape[-1]

    scattering = Scattering1D(J,
                              T,
                              Q=Q,
                              average=True,
                              oversampling=0,
                              vectorize=True)
    Sx = scattering.forward(x)
    Sx_abs = scattering.forward(x.abs())

    # Select coefficients by scattering order using the metadata.
    meta = Scattering1D.compute_meta_scattering(J, Q)
    order0 = (meta['order'] == 0)
    order1 = (meta['order'] == 1)
    order2 = (meta['order'] == 2)

    Sx1 = normalise_order1(Sx,
                           Sx_abs,
                           order0,
                           order1,
                           order2,
                           epsilon_order_1,
                           frequency_normalisation_order_1_vector=[])
    Sx2 = normalise_order2(J,
                           Q,
                           Sx,
                           order0,
                           order1,
                           order2,
                           epsilon_order_2,
                           frequency_normalisation_order_2_vector=[])

    order1_profile = np.mean(scale_value(Sx1.numpy()), axis=1)
    order2_profile = np.mean(scale_value(Sx2.numpy()), axis=1)
    return order1_profile, order2_profile
Example #3
0
def _add_scattering_column(fig, coefficients, col, y=None):
    """Draw one coefficient matrix in column ``col``: heatmap, row-wise mean and max.

    ``coefficients`` is scaled once with ``scale_value``. The scaled values go
    to a heatmap in row 1, their axis-1 mean to a line in row 2 and their
    axis-1 max to a line in row 3. When ``y`` is given it becomes the
    heatmap's y coordinates; otherwise the heatmap's y axis is reversed so
    the first coefficient row is drawn at the top.
    """
    scaled = scale_value(coefficients)
    heatmap_kwargs = {'z': scaled, 'colorscale': 'Viridis', 'showscale': False}
    if y is not None:
        heatmap_kwargs['y'] = y
    fig.add_trace(go.Heatmap(**heatmap_kwargs), row=1, col=col)
    if y is None:
        fig.update_yaxes(autorange="reversed", row=1, col=col)
    fig.add_trace(go.Scatter(y=np.mean(scaled, axis=1), name="Negative"),
                  row=2,
                  col=col)
    fig.add_trace(go.Scatter(y=np.max(scaled, axis=1), name="Negative"),
                  row=3,
                  col=col)


def plot_multi_order_scattering(x,
                                J,
                                Q,
                                order1_frequency_axis=None,
                                normalise_1=False,
                                normalise_2=False,
                                epsilon_order_1=1 * 10**-6,
                                epsilon_order_2=1 * 10**-6,
                                frequency_normalisation_order_1_vector=None,
                                frequency_normalisation_order_2_vector=None):
    """Show a 3x6 plotly figure of the order-0/1/2 scattering of ``x``.

    Layout (columns):
        1. temporal signal (row 1) and order-0 coefficients (row 2);
        2. order-1 coefficients;
        3. order-1 coefficients, normalised if ``normalise_1``;
        4. order-2 coefficients;
        5. order-2 coefficients, normalised if ``normalise_2``;
        6. ``select_frequency`` applied to column 5's data.
    Columns 2-6 show the scaled heatmap (row 1) and its axis-1 mean (row 2)
    and max (row 3).

    Args:
        x: numpy array holding the signal; flattened to a single row and
            peak-normalised. The caller's array is not modified.
        J: scattering parameter forwarded to ``Scattering1D``,
            ``compute_meta_scattering``, ``normalise_order2`` and
            ``select_frequency``.
        Q: scattering parameter forwarded to the same calls.
        order1_frequency_axis: optional y coordinates for the order-1
            heatmaps. ``None`` or an empty sequence means "not given", in
            which case those heatmaps use a reversed index axis.
        normalise_1: apply ``normalise_order1`` for column 3.
        normalise_2: apply ``normalise_order2`` for columns 5 and 6.
        epsilon_order_1: epsilon passed to ``normalise_order1``.
        epsilon_order_2: epsilon passed to ``normalise_order2``.
        frequency_normalisation_order_1_vector: forwarded to
            ``normalise_order1``.
        frequency_normalisation_order_2_vector: forwarded to
            ``normalise_order2``.
    """
    x = torch.from_numpy(x).float()
    # Divide out of place: ``torch.from_numpy`` shares memory with the
    # caller's array and ``.float()`` returns the same tensor for float32
    # input, so an in-place ``/=`` would rescale the caller's data.
    # An all-zero signal is left unchanged instead of turning into NaN (0/0).
    peak = x.abs().max()
    if peak > 0:
        x = x / peak
    # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
    x = x.reshape(1, -1)

    T = x.shape[-1]
    scattering = Scattering1D(J,
                              T,
                              Q=Q,
                              average=True,
                              oversampling=0,
                              vectorize=True)
    Sx = scattering.forward(x)
    Sx_abs = scattering.forward(x.abs())
    meta = Scattering1D.compute_meta_scattering(J, Q)
    order0 = (meta['order'] == 0)
    order1 = (meta['order'] == 1)
    order2 = (meta['order'] == 2)

    fig = make_subplots(
        rows=3,
        cols=6,
        column_widths=[0.4] * 6,
        row_heights=[0.2] * 3,
        # Fresh dicts per cell so no spec object is shared between subplots.
        specs=[
            [{"type": "Scatter"}] + [{"type": "Heatmap"} for _ in range(5)],
            [{"type": "Scatter"} for _ in range(6)],
            [None] + [{"type": "Scatter"} for _ in range(5)],
        ],
        # plotly assigns titles only to existing (non-None) subplots, so the
        # empty row-3/col-1 cell must not get a placeholder title; otherwise
        # every row-3 title shifts one cell and the last one is dropped.
        subplot_titles=(
            # Row 1
            'Temporal signal', 'Scattering Order 1',
            'Scattering Order 1 Normalised', 'Scattering Order 2',
            'Scattering Order 2 Normalised', 'Order 2 Frequency',
            # Row 2
            "Scattering Order 0", 'Scattering Order 1 mean',
            'Scattering Order 1 mean Normalised', 'Scattering Order 2 mean',
            'Scattering Order 2 mean Normalised', 'Order 2 Frequency mean',
            # Row 3 (columns 2-6)
            'Scattering Order 1 max',
            'Scattering Order 1 max Normalised', 'Scattering Order 2 max',
            'Scattering Order 2 max Normalised', 'Order 2 Frequency max'))

    # Column 1: the temporal signal and the order-0 coefficients.
    fig.add_trace(go.Scatter(y=x[0, :].numpy(), name="Negative"), row=1, col=1)
    fig.add_trace(go.Scatter(y=Sx[0, order0, :].numpy().ravel(),
                             name="Negative"),
                  row=2,
                  col=1)

    # Columns 2-3: order 1, raw and (optionally) normalised.
    if normalise_1:
        Sx1 = normalise_order1(Sx,
                               Sx_abs,
                               order0,
                               order1,
                               order2,
                               epsilon_order_1,
                               frequency_normalisation_order_1_vector=
                               frequency_normalisation_order_1_vector)
    else:
        Sx1 = Sx[0, order1, :]

    if order1_frequency_axis is not None and len(order1_frequency_axis) != 0:
        order1_axis = order1_frequency_axis
    else:
        order1_axis = None
    _add_scattering_column(fig, Sx[0, order1, :].numpy(), col=2, y=order1_axis)
    _add_scattering_column(fig, Sx1.numpy(), col=3, y=order1_axis)

    # Columns 4-5: order 2, raw and (optionally) normalised.
    if normalise_2:
        Sx2 = normalise_order2(J,
                               Q,
                               Sx,
                               order0,
                               order1,
                               order2,
                               epsilon_order_2,
                               frequency_normalisation_order_2_vector=
                               frequency_normalisation_order_2_vector)
    else:
        Sx2 = Sx[0, order2, :]

    _add_scattering_column(fig, Sx[0, order2, :].numpy(), col=4)
    _add_scattering_column(fig, Sx2.numpy(), col=5)

    # Column 6: ``select_frequency`` view of the order-2 data from column 5.
    Sx2_Bis = select_frequency(Sx2, T, J, Q, index_frequency=None)
    _add_scattering_column(fig, Sx2_Bis, col=6)

    fig.update_layout(showlegend=False)
    fig.show()
Example #4
0
# -----------
# Let's take a look at the signal spectrogram
plt.figure(figsize=(10, 10))
plt.specgram(x.numpy().ravel(), Fs=1024)
plt.title("Time-Frequency spectrogram of signal")

###############################################################################
# Doing the scattering transform
# ------------------------------
J = 6
Q = 16

# NOTE(review): the positional order (T, J, Q) matches early kymatio
# releases; later versions take Scattering1D(J, shape, Q) -- confirm against
# the installed kymatio version.
scattering = Scattering1D(T, J, Q)

# Split the coefficient axis by scattering order using the metadata.
meta = Scattering1D.compute_meta_scattering(J, Q)
order0, order1, order2 = (meta['order'] == k for k in range(3))

s = scattering.forward(x)[0]
plt.figure(figsize=(10, 10), dpi=300)
# Order 0 is drawn as a line plot; orders 1 and 2 as images.
for _panel, (_mask, _label) in enumerate([(order0, "Scattering order 0"),
                                          (order1, "Scattering order 1"),
                                          (order2, "Scattering order 2")],
                                         start=1):
    plt.subplot(3, 1, _panel)
    if _panel == 1:
        plt.plot(s[_mask].numpy())
    else:
        plt.imshow(s[_mask].numpy(), aspect='auto')
    plt.title(_label)