コード例 #1
0
    def test_shift_without_rollover(self):
        """Shift with rollover disabled zero-fills the vacated positions.

        Runs a forward (+0.5) and then a backward (-0.25) shift of a
        4-sample float32 signal, each wrapped in Compose, and checks the
        exact output values, that the dtype stays float32, and that the
        length stays 4.
        """
        audio = np.array([1.0, 0.5, 0.25, 0.125], dtype=np.float32)
        rate = 16000

        # (fraction, expected output): per the expected values, +0.5 moves the
        # content 2 slots right and -0.25 moves it 1 slot left, with the
        # vacated slots set to zero instead of wrapping around.
        cases = (
            (0.5, [0.0, 0.0, 1.0, 0.5]),
            (-0.25, [0.5, 0.25, 0.125, 0.0]),
        )
        for fraction, expected_values in cases:
            pipeline = Compose([
                Shift(
                    min_fraction=fraction,
                    max_fraction=fraction,
                    rollover=False,
                    p=1.0,
                )
            ])
            shifted = pipeline(samples=audio, sample_rate=rate)
            assert_almost_equal(
                shifted, np.array(expected_values, dtype=np.float32)
            )
            self.assertEqual(shifted.dtype, np.float32)
            self.assertEqual(len(shifted), 4)
コード例 #2
0
    def test_shift_fade_rollover_3(self):
        """Backward rollover shift with a fade much longer than the signal.

        The input is two mirrored 5-sample channels. fade_duration=1.0 at a
        4000 Hz sample rate corresponds to 4000 samples, far more than the
        signal length. The expected output values are all tiny (magnitude
        below 0.002), and the second channel mirrors the first.
        """
        rate = 4000
        ramp = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
        # Second channel is the exact negation of the first.
        stereo = np.stack([ramp, -ramp])

        shifter = Shift(
            min_fraction=-0.5,
            max_fraction=-0.5,
            rollover=True,
            fade=True,
            fade_duration=1.0,
            p=1.0,
        )
        result = shifter(samples=stereo, sample_rate=rate)

        expected_row = np.array(
            [0.0015004, 0.0010003, 0.0, 0.0, 0.0005001], dtype=np.float32
        )
        # Negating also yields -0.0 entries, matching the mirrored expectation.
        assert_almost_equal(result, np.stack([expected_row, -expected_row]))
コード例 #3
0
    def test_shift_multichannel(self):
        """A +0.5 shift on 2-channel, 4-sample audio rotates each channel.

        Shift is built without an explicit rollover argument. The expected
        output shows every channel moved 2 slots right, with the trailing
        samples wrapped to the front. The float32 dtype must be preserved.
        """
        rate = 4000
        stereo = np.array(
            [[0.75, 0.5, -0.25, -0.125], [0.9, 0.5, -0.25, -0.125]],
            dtype=np.float32)
        want = np.array(
            [[-0.25, -0.125, 0.75, 0.5], [-0.25, -0.125, 0.9, 0.5]],
            dtype=np.float32,
        )

        got = Shift(min_fraction=0.5, max_fraction=0.5, p=1.0)(
            samples=stereo, sample_rate=rate
        )

        assert_almost_equal(got, want)
        self.assertEqual(got.dtype, np.float32)
コード例 #4
0
    def test_randomize_parameters_and_apply(self):
        """Frozen parameters make repeated Compose calls deterministic.

        After freezing and explicitly randomizing the parameters once, two
        consecutive calls on the same input must produce identical output.
        Unfreezing must then leave every transform's parameters equal to the
        snapshot taken right after randomization, and must clear each
        transform's ``are_parameters_frozen`` flag.
        """
        import copy

        samples = 1.0 / np.arange(1, 21, dtype=np.float32)
        sample_rate = 44100

        augmenter = Compose([
            AddBackgroundNoise(
                sounds_path=os.path.join(DEMO_DIR, "background_noises"),
                min_snr_in_db=15,
                max_snr_in_db=35,
                p=1.0,
            ),
            ClippingDistortion(p=0.5),
            FrequencyMask(min_frequency_band=0.3,
                          max_frequency_band=0.5,
                          p=0.5),
            TimeMask(min_band_part=0.2, max_band_part=0.5, p=0.5),
            Shift(min_fraction=0.5, max_fraction=0.5, p=0.5),
        ])
        augmenter.freeze_parameters()
        augmenter.randomize_parameters(samples=samples,
                                       sample_rate=sample_rate)

        # Take a deep copy of each parameter set. Keeping bare references
        # would make the final check compare each object with itself if a
        # transform mutated its parameters in place, so that check could
        # never fail.
        parameters = [
            copy.deepcopy(transform.parameters)
            for transform in augmenter.transforms
        ]

        perturbed_samples1 = augmenter(samples=samples,
                                       sample_rate=sample_rate)
        perturbed_samples2 = augmenter(samples=samples,
                                       sample_rate=sample_rate)

        # With frozen parameters both runs must be bit-identical.
        assert_array_equal(perturbed_samples1, perturbed_samples2)

        augmenter.unfreeze_parameters()

        for transform_parameters, transform in zip(parameters,
                                                   augmenter.transforms):
            # assertEqual shows both values on failure, unlike
            # assertTrue(a == b), which only reports "False is not true".
            self.assertEqual(transform_parameters, transform.parameters)
            self.assertFalse(transform.are_parameters_frozen)
コード例 #5
0
    def test_shift_fade_rollover(self):
        """Forward rollover shift with a short fade on 2-channel audio.

        fade_duration=0.00075 at 4000 Hz is 3 samples. The expected output
        shows the rolled-over signal with samples near the shift boundary
        attenuated. The second channel mirrors the first.
        """
        rate = 4000
        ramp = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
        # Second channel is the exact negation of the first.
        stereo = np.stack([ramp, -ramp])

        fade_seconds = 0.00075  # 0.00075 * 4000 = 3
        shifter = Shift(
            min_fraction=0.5,
            max_fraction=0.5,
            rollover=True,
            fade=True,
            fade_duration=fade_seconds,
            p=1.0,
        )
        result = shifter(samples=stereo, sample_rate=rate)

        want = np.array(
            [[2.0, 0.0, 0, 1.0, 3.0], [-2.0, 0.0, 0, -1.0, -3.0]],
            dtype=np.float32,
        )
        assert_almost_equal(result, want)