def __init__(self, J, Q, audio_length):
    """Build a ``Scattering1D`` front end for signals of ``audio_length`` samples.

    Args:
        J: forwarded to ``Scattering1D`` and ``Scattering1D.compute_meta_scattering``.
        Q: forwarded to ``Scattering1D`` and ``Scattering1D.compute_meta_scattering``.
        audio_length: signal length, stored as ``self.T`` and passed to ``Scattering1D``.

    Attributes set:
        J, Q, T: the constructor arguments (``T`` is ``audio_length``).
        meta: coefficient metadata returned by ``compute_meta_scattering``.
        order0_indices, order1_indices, order2_indices: boolean masks
            ``meta['order'] == k`` selecting the coefficients of order k.
        scattering: the ``Scattering1D`` instance after ``.cuda()``.
        output_size: the value of ``scattering.output_size()``.
    """
    super(Scatter, self).__init__()
    self.J, self.Q, self.T = J, Q, audio_length

    # Per-coefficient metadata; only the 'order' field is used here to build masks.
    self.meta = Scattering1D.compute_meta_scattering(self.J, self.Q)
    coefficient_order = self.meta['order']
    self.order0_indices = (coefficient_order == 0)
    self.order1_indices = (coefficient_order == 1)
    self.order2_indices = (coefficient_order == 2)

    # The transform is moved to the GPU; this requires a CUDA-capable torch setup.
    transform = Scattering1D(self.J, self.T, self.Q)
    self.scattering = transform.cuda()
    self.output_size = self.scattering.output_size()
def normalised_frquency_vector(x, J, Q, epsilon_order_1=1 * 10**-6, epsilon_order_2=1 * 10**-6):
    """Return time-averaged, normalised order-1 and order-2 scattering profiles of ``x``.

    The signal is peak-normalised and flattened to shape ``(1, T)``. It is then passed
    through ``Scattering1D(J, T, Q=Q, average=True, oversampling=0, vectorize=True)``.
    The order-1 and order-2 coefficients are normalised with ``normalise_order1`` and
    ``normalise_order2`` (both with empty frequency-normalisation vectors). Each result
    is rescaled with ``scale_value`` and averaged over axis 1 (presumably the time
    axis; confirm against ``normalise_order1/2``).

    Args:
        x: numpy array holding the signal. It is flattened and is NOT modified.
        J: forwarded to ``Scattering1D``, ``compute_meta_scattering`` and
            ``normalise_order2``.
        Q: forwarded to ``Scattering1D``, ``compute_meta_scattering`` and
            ``normalise_order2``.
        epsilon_order_1: forwarded unchanged to ``normalise_order1``.
        epsilon_order_2: forwarded unchanged to ``normalise_order2``.

    Returns:
        Tuple ``(order1_profile, order2_profile)``: the mean over axis 1 of the scaled,
        normalised order-1 and order-2 coefficients.
    """
    x = torch.from_numpy(x).float()
    # torch.from_numpy shares memory with the caller's array, and .float() is a no-op
    # for float32 input. Normalise out of place so the caller's data is left untouched,
    # and skip the division for a silent signal to avoid producing NaNs.
    peak = x.abs().max()
    if peak > 0:
        x = x / peak
    x = x.reshape(1, -1)
    T = x.shape[-1]

    scattering = Scattering1D(J, T, Q=Q, average=True, oversampling=0, vectorize=True)
    Sx = scattering.forward(x)
    # Scattering of the rectified signal, consumed by the order-1 normalisation.
    Sx_abs = scattering.forward(x.abs())

    meta = Scattering1D.compute_meta_scattering(J, Q)
    order0 = (meta['order'] == 0)
    order1 = (meta['order'] == 1)
    order2 = (meta['order'] == 2)

    Sx1 = normalise_order1(Sx, Sx_abs, order0, order1, order2, epsilon_order_1,
                           frequency_normalisation_order_1_vector=[])
    Sx2 = normalise_order2(J, Q, Sx, order0, order1, order2, epsilon_order_2,
                           frequency_normalisation_order_2_vector=[])
    return np.mean(scale_value(Sx1.numpy()), axis=1), np.mean(scale_value(Sx2.numpy()), axis=1)
def _add_scattering_column(fig, scaled, col, y_axis=None):
    """Draw one figure column: heatmap (row 1), mean over axis 1 (row 2), max over axis 1 (row 3).

    ``scaled`` is expected to be a 2-D array already passed through ``scale_value``.
    When ``y_axis`` is given it supplies the heatmap's y coordinates; otherwise the
    heatmap's y axis is reversed.
    """
    heatmap_kwargs = dict(z=scaled, colorscale='Viridis', showscale=False)
    if y_axis is not None:
        heatmap_kwargs['y'] = y_axis
    fig.add_trace(go.Heatmap(**heatmap_kwargs), row=1, col=col)
    if y_axis is None:
        fig.update_yaxes(autorange="reversed", row=1, col=col)
    # The legend is hidden, so the trace name is never displayed.
    fig.add_trace(go.Scatter(y=np.mean(scaled, axis=1), name="Negative"), row=2, col=col)
    fig.add_trace(go.Scatter(y=np.max(scaled, axis=1), name="Negative"), row=3, col=col)


def plot_multi_order_scattering(x, J, Q, order1_frequency_axis=None, normalise_1=False, normalise_2=False,
                                epsilon_order_1=1 * 10**-6, epsilon_order_2=1 * 10**-6,
                                frequency_normalisation_order_1_vector=None,
                                frequency_normalisation_order_2_vector=None):
    """Show a 3x6 plotly grid of the scattering coefficients of ``x`` (orders 0, 1 and 2).

    Column layout (row 1 = heatmap, row 2 = mean over axis 1, row 3 = max over axis 1,
    all values passed through ``scale_value``):
        1. row 1: peak-normalised signal; row 2: order-0 coefficients.
        2. raw order-1 coefficients.
        3. order-1 coefficients, normalised with ``normalise_order1`` when ``normalise_1``.
           Otherwise this column is identical to column 2.
        4. raw order-2 coefficients.
        5. order-2 coefficients, normalised with ``normalise_order2`` when ``normalise_2``.
           Otherwise this column is identical to column 4.
        6. the output of ``select_frequency`` applied to the column-5 coefficients.

    Args:
        x: numpy array holding the signal. It is flattened and is NOT modified.
        J: forwarded to ``Scattering1D``, ``compute_meta_scattering``,
            ``normalise_order2`` and ``select_frequency``.
        Q: forwarded to ``Scattering1D``, ``compute_meta_scattering``,
            ``normalise_order2`` and ``select_frequency``.
        order1_frequency_axis: optional y coordinates for the order-1 heatmaps
            (columns 2-3). When None or empty, those heatmaps use a reversed index axis.
        normalise_1: apply ``normalise_order1`` for column 3.
        normalise_2: apply ``normalise_order2`` for column 5.
        epsilon_order_1: forwarded unchanged to ``normalise_order1``.
        epsilon_order_2: forwarded unchanged to ``normalise_order2``.
        frequency_normalisation_order_1_vector: forwarded to ``normalise_order1``.
        frequency_normalisation_order_2_vector: forwarded to ``normalise_order2``.
    """
    x = torch.from_numpy(x).float()
    # torch.from_numpy shares memory with the caller's array, and .float() is a no-op
    # for float32 input. Normalise out of place so the caller's data is left untouched,
    # and skip the division for a silent signal to avoid producing NaNs.
    peak = x.abs().max()
    if peak > 0:
        x = x / peak
    x = x.reshape(1, -1)
    T = x.shape[-1]

    scattering = Scattering1D(J, T, Q=Q, average=True, oversampling=0, vectorize=True)
    Sx = scattering.forward(x)
    meta = Scattering1D.compute_meta_scattering(J, Q)
    order0 = (meta['order'] == 0)
    order1 = (meta['order'] == 1)
    order2 = (meta['order'] == 2)

    fig = make_subplots(
        rows=3, cols=6,
        column_widths=[0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
        row_heights=[0.2, 0.2, 0.2],
        specs=[[{"type": "Scatter"}, {"type": "Heatmap"}, {"type": "Heatmap"},
                {"type": "Heatmap"}, {"type": "Heatmap"}, {"type": "Heatmap"}],
               [{"type": "Scatter"}, {"type": "Scatter"}, {"type": "Scatter"},
                {"type": "Scatter"}, {"type": "Scatter"}, {"type": "Scatter"}],
               [None, {"type": "Scatter"}, {"type": "Scatter"},
                {"type": "Scatter"}, {"type": "Scatter"}, {"type": "Scatter"}]],
        subplot_titles=(
            'Temporal signal', 'Scattering Order 1', 'Scattering Order 1 Normalised',
            'Scattering Order 2', 'Scattering Order 2 Normalised', 'Order 2 Frequency',
            "Scattering Order 0", 'Scattering Order 1 mean', 'Scattering Order 1 mean Normalised',
            'Scattering Order 2 mean', 'Scattering Order 2 mean Normalised', 'Order 2 Frequency mean',
            None, 'Scattering Order 1 max', 'Scattering Order 1 max Normalised',
            'Scattering Order 2 max', 'Scattering Order 2 max Normalised', 'Order 2 Frequency max'))

    # Column 1: the (normalised) waveform and the order-0 coefficients.
    fig.add_trace(go.Scatter(y=x[0, :].numpy(), name="Negative"), row=1, col=1)
    fig.add_trace(go.Scatter(y=Sx[0, order0, :].numpy().ravel(), name="Negative"), row=2, col=1)

    # Columns 2-3: order 1, raw and (optionally) normalised.
    if normalise_1:
        # Scattering of the rectified signal is only needed by the order-1 normalisation.
        Sx_abs = scattering.forward(x.abs())
        Sx1 = normalise_order1(Sx, Sx_abs, order0, order1, order2, epsilon_order_1,
                               frequency_normalisation_order_1_vector=frequency_normalisation_order_1_vector)
    else:
        Sx1 = Sx[0, order1, :]

    if order1_frequency_axis is not None and len(order1_frequency_axis) != 0:
        order1_y = order1_frequency_axis
    else:
        order1_y = None
    _add_scattering_column(fig, scale_value(Sx[0, order1, :].numpy()), 2, order1_y)
    _add_scattering_column(fig, scale_value(Sx1.numpy()), 3, order1_y)

    # Columns 4-5: order 2, raw and (optionally) normalised.
    if normalise_2:
        Sx2 = normalise_order2(J, Q, Sx, order0, order1, order2, epsilon_order_2,
                               frequency_normalisation_order_2_vector=frequency_normalisation_order_2_vector)
    else:
        Sx2 = Sx[0, order2, :]
    _add_scattering_column(fig, scale_value(Sx[0, order2, :].numpy()), 4)
    _add_scattering_column(fig, scale_value(Sx2.numpy()), 5)

    # Column 6: frequency selection applied to the column-5 coefficients.
    Sx2_Bis = select_frequency(Sx2, T, J, Q, index_frequency=None)
    _add_scattering_column(fig, scale_value(Sx2_Bis), 6)

    fig.update_layout(showlegend=False)
    fig.show()
# -----------
# Let's take a look at the signal spectrogram.
# NOTE(review): `x` (a torch tensor) and `T` (its length) are defined earlier in the
# script, outside this excerpt.
plt.figure(figsize=(10, 10))
plt.specgram(x.numpy().ravel(), Fs=1024)
plt.title("Time-Frequency spectrogram of signal")

###############################################################################
# Doing the scattering transform
# ------------------------------
J = 6
Q = 16
# Arguments are (J, T, Q): J first, then the signal length. The previous (T, J, Q)
# order passed the signal length in as J.
scattering = Scattering1D(J, T, Q)

# Get the metadata on the coordinates of the scattering coefficients and build
# boolean masks selecting each scattering order.
meta = Scattering1D.compute_meta_scattering(J, Q)
order0 = (meta['order'] == 0)
order1 = (meta['order'] == 1)
order2 = (meta['order'] == 2)

s = scattering.forward(x)[0]

plt.figure(figsize=(10, 10), dpi=300)
plt.subplot(3, 1, 1)
plt.plot(s[order0].numpy())
plt.title("Scattering order 0")
plt.subplot(3, 1, 2)
plt.imshow(s[order1].numpy(), aspect='auto')
plt.title("Scattering order 1")
plt.subplot(3, 1, 3)
plt.imshow(s[order2].numpy(), aspect='auto')
plt.title("Scattering order 2")