# NOTE(review): this text was collapsed onto a single physical line and is not
#   valid Python as written. Restore the original line breaks/indentation.
#   Which of the statements below sit inside `else:` (whose `if` is outside
#   this view) and which are top-level cannot be determined from here.
#
# Keras training setup:
# - `else:` branch: builds a fresh model with AV.AV_model(people_num).
# - Builds shuffled training/validation AVGenerator objects from trainfile /
#   valfile under database_dir_path with the given batch_size.
# - NUM_GPU > 1: wraps AV_model in ModelMGPU, compiles the wrapper with Adam
#   and audio_loss(gamma=gamma_loss, beta=beta_loss, num_speaker=people_num),
#   prints the base model's summary, then trains with fit_generator using a
#   TensorBoard callback (log_dir='./log_AV') plus `checkpoint` and `rlr`,
#   starting at initial_epoch.
# - NUM_GPU <= 1: compiles AV_model directly with the same optimizer and loss
#   and prints its summary.
# NOTE(review): as visible here, the single-GPU branch never calls
#   fit_generator, so no training would happen on one GPU. The call may have
#   been lost when this text was collapsed; confirm against the original script.
# NOTE(review): Keras Model.summary() prints the summary itself and returns
#   None, so print(AV_model.summary()) likely also prints "None".
else: AV_model = AV.AV_model(people_num) train_generator = AVGenerator(trainfile, database_dir_path=database_dir_path, batch_size=batch_size, shuffle=True) val_generator = AVGenerator(valfile, database_dir_path=database_dir_path, batch_size=batch_size, shuffle=True) if NUM_GPU > 1: parallel_model = ModelMGPU(AV_model, NUM_GPU) adam = optimizers.Adam() loss = audio_loss(gamma=gamma_loss, beta=beta_loss, num_speaker=people_num) parallel_model.compile(loss=loss, optimizer=adam) print(AV_model.summary()) parallel_model.fit_generator( generator=train_generator, validation_data=val_generator, epochs=epochs, workers=workers, use_multiprocessing=use_multiprocessing, callbacks=[TensorBoard(log_dir='./log_AV'), checkpoint, rlr], initial_epoch=initial_epoch) if NUM_GPU <= 1: adam = optimizers.Adam() loss = audio_loss(gamma=gamma_loss, beta=beta_loss, num_speaker=people_num) AV_model.compile(optimizer=adam, loss=loss) print(AV_model.summary())
# NOTE(review): this text arrived collapsed onto one physical line. The next
#   statement reads `v`, presumably the handle of a `with open(...) as v:`
#   block that is not visible here. Confirm, and restore that context and
#   its indentation.
valfile = v.readlines()

# PyTorch training: build the model, wrap the train/val AVGenerator datasets
# in DataLoaders, and minimise audio_loss with Adam (lr=1e-4).
AV_model = AV.AV_model(people_num)
# NOTE(review): AVGenerator already receives batch_size, while DataLoader's
#   default batch_size=1 may add another leading dimension. Confirm the
#   shapes the model expects.
train_loader = DataLoader(
    AVGenerator(trainfile, database_dir_path=database_dir_path,
                batch_size=batch_size, shuffle=True))
val_loader = DataLoader(
    AVGenerator(valfile, database_dir_path=database_dir_path,
                batch_size=batch_size, shuffle=True))
# NOTE(review): val_loader is built but never iterated; no validation pass runs.

optimizer = torch.optim.Adam(AV_model.parameters(), lr=1e-4)
lossfunc = audio_loss(gamma=gamma_loss, num_speaker=people_num)

for epoch in range(num_epoch):
    print(epoch)
    for batch in train_loader:
        # Reset gradients from the previous step. backward() accumulates into
        # .grad, so without this every step would apply the running sum of all
        # past gradients.
        optimizer.zero_grad()
        # Variable is a no-op wrapper on PyTorch >= 0.4; kept for older versions.
        preds = AV_model(Variable(batch[0]))
        loss = lossfunc(preds, Variable(batch[1]))
        print(loss)
        loss.backward()
        optimizer.step()
    # Checkpoint the weights once every 10 epochs (including epoch 0).
    if epoch % 10 == 0:
        torch.save(AV_model.state_dict(), str(epoch) + '.pt')

import os
import scipy.io.wavfile as wavfile
import numpy as np
import utils