예제 #1
0
def test_train():
    """Run a short training job and check the loss against a threshold of 2.

    A model built with ``local_size=1`` / ``local_scale=40`` is trained through
    ``train_support`` on a ``SignWaveDataset`` for 500 iterations with
    ``use_gpu=True`` and no discriminator.  Afterwards the predictor's
    ``state_dict`` is written under ``/tmp``.

    NOTE(review): the two hooks are only handed to ``train_support`` here;
    presumably it calls them on the first and the last reported observation --
    confirm against ``train_support``.
    """
    local_scale = 40
    loss_threshold = 2
    iteration = 500

    model, _ = _create_model(local_size=1, local_scale=local_scale)
    dataset = SignWaveDataset(
        sampling_length=16000,
        sampling_rate=16000,
        local_padding_length=0,
        local_scale=local_scale,
    )

    def _expect_loss_above_threshold(o):
        loss = o["main/loss"].data
        assert loss > loss_threshold

    def _expect_loss_below_threshold(o):
        loss = o["main/loss"].data
        assert loss < loss_threshold

    train_support(
        batch_size=8,
        use_gpu=True,
        model=model,
        discriminator_model=None,
        dataset=dataset,
        iteration=iteration,
        first_hook=_expect_loss_above_threshold,
        last_hook=_expect_loss_below_threshold,
    )

    # Keep the trained predictor weights on disk under /tmp.
    weights_path = (
        "/tmp/"
        f"test_training"
        f"-speaker_size=0"
        f"-iteration={iteration}"
        ".pth"
    )
    torch.save(model.predictor.state_dict(), weights_path)
예제 #2
0
    def _wrapper(self, bit=10, mulaw=True):
        """Train on a ``SignWaveDataset`` and compare ``main/nll_coarse`` to a reference.

        ``sampling_rate``, ``sampling_length``, ``batch_size``, ``gpu`` and
        ``iteration`` are not defined here; they are read from an enclosing
        scope.  After training, the predictor's ``state_dict`` is saved under
        ``/tmp`` with ``bit``, ``mulaw`` and ``iteration`` in the file name.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        model = _create_model(local_size=0)
        dataset_kwargs = dict(
            sampling_rate=sampling_rate,
            sampling_length=sampling_length,
            bit=bit,
            mulaw=mulaw,
        )
        dataset = SignWaveDataset(**dataset_kwargs)

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()
        nll_key = "main/nll_coarse"

        def _expect_above_reference(o):
            is_above = o[nll_key] > reference_nll
            self.assertTrue(is_above)

        def _expect_below_reference(o):
            is_below = o[nll_key] < reference_nll
            self.assertTrue(is_below)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )

        # Keep the trained predictor weights on disk under /tmp.
        weights_path = (
            "/tmp/"
            f"test_training_wavernn"
            f"-bit={bit}"
            f"-mulaw={mulaw}"
            f"-speaker_size=0"
            f"-iteration={iteration}.pth"
        )
        torch.save(model.predictor.state_dict(), weights_path)
예제 #3
0
    def _wrapper(self, bit=10, mulaw=True):
        """Train on a ``DownLocalRandomDataset`` and compare ``main/nll_coarse`` to a reference.

        The model is created with ``local_size=2 * scale`` and
        ``local_scale=scale`` for ``scale = 4``; the same ``scale`` is passed
        to the dataset.  ``sampling_rate``, ``sampling_length``,
        ``batch_size``, ``gpu`` and ``iteration`` are read from an enclosing
        scope.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        scale = 4
        local_size = 2 * scale

        model = _create_model(local_size=local_size, local_scale=scale)
        dataset_kwargs = dict(
            sampling_rate=sampling_rate,
            sampling_length=sampling_length,
            scale=scale,
            bit=bit,
            mulaw=mulaw,
        )
        dataset = DownLocalRandomDataset(**dataset_kwargs)

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()
        nll_key = "main/nll_coarse"

        def _expect_above_reference(o):
            is_above = o[nll_key] > reference_nll
            self.assertTrue(is_above)

        def _expect_below_reference(o):
            is_below = o[nll_key] < reference_nll
            self.assertTrue(is_below)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )
예제 #4
0
    def _wrapper(self, to_double=False, bit=10, mulaw=True):
        """Train on a ``DownLocalRandomDataset`` and compare the reported NLLs to a reference.

        ``main/nll_coarse`` is always checked; ``main/nll_fine`` is checked as
        well when ``to_double`` is truthy.  ``sampling_length``,
        ``batch_size``, ``gpu`` and ``iteration`` are read from an enclosing
        scope.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        scale = 4
        local_size = 2 * scale

        model = _create_model(local_size=local_size, local_scale=scale)
        dataset_kwargs = dict(
            sampling_length=sampling_length,
            scale=scale,
            to_double=to_double,
            bit=bit,
            mulaw=mulaw,
            local_padding_size=0,
        )
        dataset = DownLocalRandomDataset(**dataset_kwargs)

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()

        # Coarse NLL is always inspected; fine NLL only in the to_double case.
        watched_keys = ['main/nll_coarse']
        if to_double:
            watched_keys.append('main/nll_fine')

        def _expect_above_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data > reference_nll)

        def _expect_below_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data < reference_nll)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )
예제 #5
0
    def _wrapper(self, to_double=False, bit=10, mulaw=True):
        """Train a 4-speaker model and compare the reported NLLs to a reference.

        One ``SignWaveDataset`` is created per speaker with frequency
        ``(i + 1) * 110``; the datasets are concatenated and every sample is
        labelled with the index of the dataset it came from.
        ``main/nll_coarse`` is always checked, ``main/nll_fine`` only when
        ``to_double`` is truthy.  After training, ``model.predictor`` is saved
        with ``serializers.save_npz`` under ``/tmp``.  ``sampling_rate``,
        ``sampling_length``, ``batch_size``, ``gpu`` and ``iteration`` are read
        from an enclosing scope.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        speaker_size = 4
        model = _create_model(local_size=0, speaker_size=speaker_size)

        datasets = []
        for speaker_index in range(speaker_size):
            datasets.append(
                SignWaveDataset(
                    sampling_rate=sampling_rate,
                    sampling_length=sampling_length,
                    to_double=to_double,
                    bit=bit,
                    mulaw=mulaw,
                    frequency=(speaker_index + 1) * 110,
                )
            )

        wave_dataset = ConcatenatedDataset(*datasets)
        # One speaker id per sample, in the same order as the concatenation.
        speaker_nums = [
            speaker_index
            for speaker_index, speaker_dataset in enumerate(datasets)
            for _ in range(len(speaker_dataset))
        ]
        dataset = SpeakerWavesDataset(
            wave_dataset=wave_dataset,
            speaker_nums=speaker_nums,
        )

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()

        # Coarse NLL is always inspected; fine NLL only in the to_double case.
        watched_keys = ['main/nll_coarse']
        if to_double:
            watched_keys.append('main/nll_fine')

        def _expect_above_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data > reference_nll)

        def _expect_below_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data < reference_nll)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )

        # Keep the trained predictor on disk under /tmp.
        npz_path = (
            '/tmp/'
            f'test_training_wavernn'
            f'-to_double={to_double}'
            f'-bit={bit}'
            f'-mulaw={mulaw}'
            f'-speaker_size={speaker_size}'
            f'-iteration={iteration}.npz'
        )
        serializers.save_npz(npz_path, model.predictor)
예제 #6
0
    def _wrapper(self, bit=10, mulaw=True):
        """Train a 4-speaker model and compare ``main/nll_coarse`` to a reference.

        One ``SignWaveDataset`` is created per speaker with frequency
        ``(i + 1) * 110``; the datasets are wrapped in a ``ConcatDataset`` and
        every sample is labelled with the index of the dataset it came from.
        After training, the predictor's ``state_dict`` is saved under
        ``/tmp``.  ``sampling_rate``, ``sampling_length``, ``batch_size``,
        ``gpu`` and ``iteration`` are read from an enclosing scope.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        speaker_size = 4
        model = _create_model(local_size=0, speaker_size=speaker_size)

        frequencies = [(i + 1) * 110 for i in range(speaker_size)]
        datasets = [
            SignWaveDataset(
                sampling_rate=sampling_rate,
                sampling_length=sampling_length,
                bit=bit,
                mulaw=mulaw,
                frequency=frequency,
            )
            for frequency in frequencies
        ]

        wave_dataset = ConcatDataset(datasets)
        # One speaker id per sample, in the same order as the concatenation.
        speaker_nums = []
        for speaker_index, speaker_dataset in enumerate(datasets):
            speaker_nums.extend([speaker_index] * len(speaker_dataset))
        dataset = SpeakerWavesDataset(
            wave_dataset=wave_dataset,
            speaker_nums=speaker_nums,
        )

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()
        nll_key = "main/nll_coarse"

        def _expect_above_reference(o):
            is_above = o[nll_key] > reference_nll
            self.assertTrue(is_above)

        def _expect_below_reference(o):
            is_below = o[nll_key] < reference_nll
            self.assertTrue(is_below)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )

        # Keep the trained predictor weights on disk under /tmp.
        weights_path = (
            "/tmp/"
            f"test_training_wavernn"
            f"-bit={bit}"
            f"-mulaw={mulaw}"
            f"-speaker_size={speaker_size}"
            f"-iteration={iteration}.pth"
        )
        torch.save(model.predictor.state_dict(), weights_path)
예제 #7
0
def test_train_conditional_discriminator():
    """Run a short training job with a ``cgan`` discriminator attached.

    A model/discriminator pair built with ``local_size=1`` /
    ``local_scale=40`` and ``DiscriminatorType.cgan`` is trained through
    ``train_support`` on a ``SignWaveDataset`` for 500 iterations with
    ``use_gpu=True``.  The first hook requires ``main/loss`` above 3 and a
    ``discriminator/loss`` entry in the observation; the last hook requires
    ``main/loss`` below 3.  Afterwards the predictor's ``state_dict`` is
    written under ``/tmp``.

    NOTE(review): the hooks are only handed to ``train_support`` here;
    presumably it calls them on the first and the last reported observation --
    confirm against ``train_support``.
    """
    # Single source of truth for the model setup and the checkpoint name.
    discriminator_type = DiscriminatorType.cgan

    model, discriminator_model = _create_model(
        local_size=1, local_scale=40, discriminator_type=discriminator_type,
    )
    dataset = SignWaveDataset(
        sampling_length=16000,
        sampling_rate=16000,
        local_padding_length=0,
        local_scale=40,
    )

    def first_hook(o):
        assert o["main/loss"].data > 3
        assert "discriminator/loss" in o

    def last_hook(o):
        assert o["main/loss"].data < 3

    iteration = 500
    train_support(
        batch_size=8,
        use_gpu=True,
        model=model,
        discriminator_model=discriminator_model,
        dataset=dataset,
        iteration=iteration,
        first_hook=first_hook,
        last_hook=last_hook,
    )

    # save model
    # The ".pth" suffix used to sit before the discriminator tag, leaving the
    # saved file without an extension; it now terminates the name.
    torch.save(
        model.predictor.state_dict(),
        (
            "/tmp/"
            f"test_training"
            f"-speaker_size=0"
            f"-iteration={iteration}"
            f"-discriminator_type={discriminator_type}"
            ".pth"
        ),
    )
예제 #8
0
def test_train(mulaw: bool):
    """Run a short training job and check the loss against a threshold of 0.5.

    Uses CUDA when ``torch.cuda.is_available()`` and the CPU otherwise.  The
    model is initialised with ``init_weights(model, "orthogonal")``, moved to
    the chosen device and trained through ``train_support`` with batch size 16
    and learning rate ``2e-4``.  Afterwards the predictor's ``state_dict`` is
    written under ``/tmp``.  ``iteration`` is not defined in this function; it
    is read from an enclosing scope.

    NOTE(review): the hooks are only handed to ``train_support`` here;
    presumably it calls them on the first and the last reported observation --
    confirm against ``train_support``.

    Args:
        mulaw: Forwarded to ``create_sign_wave_dataset`` and embedded in the
            checkpoint file name.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    network_config = create_network_config()
    predictor = create_predictor(network_config)
    model_config = create_model_config()
    model = Model(
        model_config=model_config,
        predictor=predictor,
        local_padding_length=0,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    dataset = create_sign_wave_dataset(mulaw=mulaw)
    loss_threshold = 0.5

    def _expect_loss_above_threshold(o):
        loss = o["main/loss"].data
        assert loss > loss_threshold

    def _expect_loss_below_threshold(o):
        loss = o["main/loss"].data
        assert loss < loss_threshold

    train_support(
        batch_size=16,
        device=device,
        model=model,
        dataset=dataset,
        iteration=iteration,
        first_hook=_expect_loss_above_threshold,
        last_hook=_expect_loss_below_threshold,
        learning_rate=2e-4,
    )

    # Keep the trained predictor weights on disk under /tmp.
    weights_path = f"/tmp/test_training-mulaw={mulaw}-iteration={iteration}.pth"
    torch.save(model.predictor.state_dict(), weights_path)
예제 #9
0
    def _wrapper(self, to_double=False, bit=10, mulaw=True):
        """Train on a ``SignWaveDataset`` and compare the reported NLLs to a reference.

        ``main/nll_coarse`` is always checked; ``main/nll_fine`` is checked as
        well when ``to_double`` is truthy.  After training, ``model.predictor``
        is saved with ``serializers.save_npz`` under ``/tmp``.
        ``sampling_rate``, ``sampling_length``, ``batch_size``, ``gpu`` and
        ``iteration`` are read from an enclosing scope.

        NOTE(review): the hooks are only passed to ``train_support``;
        presumably they run on the first and last observation -- confirm.
        """
        model = _create_model(local_size=0)
        dataset_kwargs = dict(
            sampling_rate=sampling_rate,
            sampling_length=sampling_length,
            to_double=to_double,
            bit=bit,
            mulaw=mulaw,
        )
        dataset = SignWaveDataset(**dataset_kwargs)

        updater, reporter = setup_support(batch_size, gpu, model, dataset)
        reference_nll = _get_trained_nll()

        # Coarse NLL is always inspected; fine NLL only in the to_double case.
        watched_keys = ['main/nll_coarse']
        if to_double:
            watched_keys.append('main/nll_fine')

        def _expect_above_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data > reference_nll)

        def _expect_below_reference(o):
            for key in watched_keys:
                self.assertTrue(o[key].data < reference_nll)

        train_support(
            iteration,
            reporter,
            updater,
            _expect_above_reference,
            _expect_below_reference,
        )

        # Keep the trained predictor on disk under /tmp.
        npz_path = (
            '/tmp/'
            f'test_training_wavernn'
            f'-to_double={to_double}'
            f'-bit={bit}'
            f'-mulaw={mulaw}'
            f'-speaker_size=0'
            f'-iteration={iteration}.npz'
        )
        serializers.save_npz(npz_path, model.predictor)