Example 1
    def test_train_step(self):
        print(" ---- Test the network backpropagation ----")
        model = FFTNetModel(hid_channels=256, out_channels=256, n_layers=11, cond_channels=80)
        inp = torch.rand(2, 1, 2048)
        c_inp = torch.rand(2, 80, 2048)

        criterion = torch.nn.L1Loss().to(device)

        model.train()
        model_ref = copy.deepcopy(model)
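        # sanity check: the deep copy must start with exactly the same parameters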
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
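        # run a few optimization steps against a zero target so that every
        # parameter receives a gradient and gets updated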
        for i in range(5):
            out = model(inp, c_inp)
            optimizer.zero_grad()
            loss = criterion(out, torch.zeros(out.shape))
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore the pre-highway layer since it is applied conditionally
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
            count += 1
Example 2
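    # assumed context: os, shutil, torch, the parsed config `c`, CLI `args`,
    # use_cuda, and the project helpers (create_experiment_folder, count_parameters,
    # MaskedCrossEntropyLoss, LJSpeechDataset, SummaryWriter) are imported/defined
    # earlier in this script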
    _ = os.path.dirname(os.path.realpath(__file__))
    OUT_PATH = os.path.join(_, c.output_path)
    OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, True)
    CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
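    # keep a copy of the config used for this run next to its outputs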
    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))

    # setup TensorBoard
    tb = SummaryWriter(OUT_PATH)

    # create the FFTNet model
    model = FFTNetModel(hid_channels=256,
                        out_channels=256,
                        n_layers=c.num_quant,
                        cond_channels=80)
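    # cross-entropy over the quantized output classes; the "Masked" variant
    # presumably ignores padded positions in each batch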
    criterion = MaskedCrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=c.lr)

    num_params = count_parameters(model)
    print(" > Models has {} parameters".format(num_params))

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # these two classes extend torch.utils.data.Dataset class to create the batches
    # the batches are tuples of three elements: wav, mels, audio file name
    train_dataset = LJSpeechDataset(
        os.path.join(c.data_path, "mels", "meta_fftnet_train.csv"),
        os.path.join(c.data_path, "mels"),
        c.sample_rate, c.num_mels, c.num_freq,
        c.min_level_db, c.frame_shift_ms, c.frame_length_ms, c.preemphasis,