Example #1
def synthesize(z, rate=1):
    # Project-level names (VQVAE, class2value, gen_waveform, randomname,
    # device, fs, lf0_start_idx) are assumed to be defined elsewhere in the project.

    # Load the trained VQ-VAE model.
    vqvae = VQVAE(num_layers=2, z_dim=1, num_class=4, input_linguistic_dim=289 + 2).to(device)
    vqvae.load_state_dict(torch.load('static/model/vqvae_model_40.pth', map_location=torch.device(device)))

    # Linguistic features, acoustic features, and squeezed mora indices for the utterance
    # "水をマレーシアから買わなくてはならないのです" ("We have to buy water from Malaysia").
    data = [np.loadtxt('static/data/ling_F_chicago.csv'),
            np.loadtxt('static/data/acou_F_chicago.csv'),
            np.loadtxt('static/data/squeezed_mora_index_chicago.csv').reshape(-1)]

    # Map each class index in z onto its quantized latent value.
    z_tf = np.array([class2value(int(cl), vqvae) for cl in z]).reshape(-1, 1)

    with torch.no_grad():
        # Rebuild the linguistic features and append the two extra flag columns
        # expected by the decoder (input_linguistic_dim = 289 + 2).
        linguistic_f = data[0]
        linguistic_f = np.concatenate((linguistic_f[:, :285], linguistic_f[:, -4:],
                                       np.ones((linguistic_f.shape[0], 1)),
                                       np.zeros((linguistic_f.shape[0], 1))), axis=1)
        linguistic_f = torch.from_numpy(linguistic_f).float().to(device)
        # Predict the log-F0 contour from the latent codes and linguistic features.
        pred_lf0 = vqvae.decode(torch.from_numpy(z_tf).float().to(device), linguistic_f,
                                data[2], tokyo=False).cpu().numpy().reshape(-1)

    # Overwrite the log-F0 stream of the original acoustic features with the prediction
    # and zero out the two following columns.
    y_base = data[1].copy()
    y_base[:, lf0_start_idx] = pred_lf0
    y_base[:, lf0_start_idx + 1:lf0_start_idx + 3] = 0

    # Vocode to a waveform and write it to a randomly named wav file.
    waveform = gen_waveform(y_base)
    filepath = './static/wav/BASIC5000_0001_{}.wav'.format(randomname(10))
    wavfile.write(filepath, rate=int(fs * rate), data=waveform.astype(np.int16))

    return filepath
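
A minimal usage sketch (an illustration, not from the source): z is assumed here to be a per-mora sequence of latent class indices in the range 0-3, matching num_class=4 above.

# Hypothetical call: one latent class per mora, normal playback rate.
wav_path = synthesize(z=[0, 2, 1, 3], rate=1)
print(wav_path)  # e.g. './static/wav/BASIC5000_0001_<random>.wav'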
Example #2
# Names like mora_index_lists_for_model, VQVAE, device, and args are assumed to be
# defined earlier in the script.
train_mora_index_lists = []
test_mora_index_lists = []
# train_files, test_files = train_test_split(files, test_size=test_size, random_state=random_state)

# Deterministic 1-in-20 split: every 20th item starting at index 1 is held out as
# test data (skipped here), every 20th item starting at index 0 is labeled valid and
# collected in test_mora_index_lists, and the rest go to training (see the sketch below).
for i, mora_i in enumerate(mora_index_lists_for_model):
    if (i - 1) % 20 == 0:  # test (held out)
        pass
    elif i % 20 == 0:  # valid
        test_mora_index_lists.append(mora_i)
    else:
        train_mora_index_lists.append(mora_i)
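
# Illustrative check of the split ratio above with dummy items (not part of the
# original code; mora_index_lists_for_model normally holds per-utterance mora
# index arrays). With 100 items: 5 are held out, 5 are collected for evaluation,
# and 90 go to training.
dummy = list(range(100))
_train, _valid = [], []
for i, item in enumerate(dummy):
    if (i - 1) % 20 == 0:    # held-out test items
        continue
    elif i % 20 == 0:        # evaluation items
        _valid.append(item)
    else:                    # training items
        _train.append(item)
print(len(_train), len(_valid))  # -> 90 5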

model = VQVAE().to(device)

# Optionally resume from an existing checkpoint.
if args.model_path != '':
    model.load_state_dict(torch.load(args.model_path))

optimizer = optim.Adam(model.parameters(), lr=2e-3)  # 1e-3

start = time.time()
beta = 0.3
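
# Generic sketch of how beta is commonly used in a VQ-VAE objective (an assumption
# for illustration; the project's actual loss_function below may differ):
# reconstruction loss + codebook loss + beta-weighted commitment loss.
def vqvae_loss_sketch(recon_x, x, z_q, z_e, beta=0.3):
    recon = F.mse_loss(recon_x.view(-1), x.view(-1), reduction='sum')
    codebook = F.mse_loss(z_q, z_e.detach(), reduction='sum')    # pull codebook vectors toward encoder outputs
    commitment = F.mse_loss(z_e, z_q.detach(), reduction='sum')  # keep encoder outputs near their chosen codes
    return recon + codebook + beta * commitment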


# Reconstruction + latent quantization losses summed over all elements and batch
def loss_function(recon_x, x, z, z_unquantized):
    # Sum-reduced MSE reconstruction loss.
    # Alternative: F.binary_cross_entropy(recon_x.view(-1), x.view(-1), reduction='sum')
    MSE = F.mse_loss(recon_x.view(-1), x.view(-1), reduction='sum')

    with torch.no_grad():
        z_no_grad = z