def val_dataloader(self):
    """Return the validation DataLoader.

    Builds a windowed dataset (0.2 s windows, 0.01 s overlap) from the
    validation sample list and serves it without shuffling, using the
    loader kwargs stored on the instance.
    """
    valid_dataset = Window_Loader(filename=args.valid_set,
                                  windowed=True,
                                  window_length=0.2,
                                  overlap=0.01)
    return torch.utils.data.DataLoader(valid_dataset,
                                       batch_size=args.test_batch_size,
                                       shuffle=False,
                                       **self.kwargs)
def main():
    """Train the SincNet + MLP pair on windowed audio data.

    Reads all hyper-parameters from the config file returned by
    ``read_conf()``, builds train/test loaders over ``Window_Loader``,
    optionally restores saved weights, runs the train/eval loop, logs
    metrics to wandb when enabled, and checkpoints every 5 epochs.
    """
    # Read config file.
    options = read_conf()

    # Log on wandb?
    log = options.wandb
    project_name = options.project
    if log:
        wandb.init(project='SincNet_Triplet')
        wandb.run.name = project_name

    device = torch.device("cuda:0")
    kwargs = {'num_workers': 4, 'pin_memory': True}

    # Data and checkpoint paths.
    data_PATH = options.path
    sincnet_path = options.sincnet_path
    mlp_path = options.mlp_path
    load = options.load

    # [cnn] hyper-parameters (comma-separated values in the config file).
    cnn_N_filt = list(map(int, options.cnn_N_filt.split(',')))
    cnn_len_filt = list(map(int, options.cnn_len_filt.split(',')))
    cnn_max_pool_len = list(map(int, options.cnn_max_pool_len.split(',')))
    cnn_use_laynorm_inp = str_to_bool(options.cnn_use_laynorm_inp)
    cnn_use_batchnorm_inp = str_to_bool(options.cnn_use_batchnorm_inp)
    cnn_use_laynorm = list(map(str_to_bool, options.cnn_use_laynorm.split(',')))
    cnn_use_batchnorm = list(
        map(str_to_bool, options.cnn_use_batchnorm.split(',')))
    cnn_act = list(map(str, options.cnn_act.split(',')))
    cnn_drop = list(map(float, options.cnn_drop.split(',')))

    # [dnn] hyper-parameters.
    fc_lay = list(map(int, options.fc_lay.split(',')))
    fc_drop = list(map(float, options.fc_drop.split(',')))
    fc_use_laynorm_inp = str_to_bool(options.fc_use_laynorm_inp)
    fc_use_batchnorm_inp = str_to_bool(options.fc_use_batchnorm_inp)
    fc_use_batchnorm = list(
        map(str_to_bool, options.fc_use_batchnorm.split(',')))
    fc_use_laynorm = list(map(str_to_bool, options.fc_use_laynorm.split(',')))
    fc_act = list(map(str, options.fc_act.split(',')))

    # [optimization] hyper-parameters.
    lr = float(options.lr)
    batch_size = int(options.batch_size)
    N_epochs = int(options.N_epochs)
    N_batches = int(options.N_batches)  # NOTE(review): parsed but unused here
    N_eval_epoch = int(options.N_eval_epoch)  # NOTE(review): parsed but unused here
    seed = int(options.seed)
    # BUGFIX: seed the RNG with the configured value; the original
    # hard-coded torch.manual_seed(120) and silently ignored `seed`.
    torch.manual_seed(seed)

    train_loader = data.DataLoader(Window_Loader(path=data_PATH,
                                                 spectrogram=False,
                                                 train=True),
                                   batch_size=batch_size,
                                   shuffle=True,
                                   **kwargs)
    test_loader = data.DataLoader(Window_Loader(path=data_PATH,
                                                spectrogram=False,
                                                train=False),
                                  batch_size=batch_size,
                                  shuffle=True,
                                  **kwargs)

    SincNet_args = {
        # NOTE(review): 3200 samples at fs=16000 is 0.2 s, not the
        # "3 seconds" the original comment claimed -- confirm intent.
        'input_dim': 3200,
        'fs': 16000,
        'cnn_N_filt': cnn_N_filt,
        'cnn_len_filt': cnn_len_filt,
        'cnn_max_pool_len': cnn_max_pool_len,
        'cnn_use_laynorm_inp': cnn_use_laynorm_inp,
        'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp,
        'cnn_use_laynorm': cnn_use_laynorm,
        'cnn_use_batchnorm': cnn_use_batchnorm,
        'cnn_act': cnn_act,
        'cnn_drop': cnn_drop
    }
    SincNet_model = SincNet(SincNet_args)
    SincNet_model.to(device)

    DNN1_args = {
        'input_dim': SincNet_model.out_dim,
        'fc_lay': fc_lay,
        'fc_drop': fc_drop,
        'fc_use_batchnorm': fc_use_batchnorm,
        'fc_use_laynorm': fc_use_laynorm,
        'fc_use_laynorm_inp': fc_use_laynorm_inp,
        'fc_use_batchnorm_inp': fc_use_batchnorm_inp,
        'fc_act': fc_act
    }
    MLP_net = MLP(DNN1_args)
    MLP_net.to(device)

    if load:
        # BUGFIX: catch the specific failures (missing file, state-dict
        # mismatch) instead of a bare ``except:``, which also swallows
        # KeyboardInterrupt/SystemExit; report the cause.
        try:
            SincNet_model.load_state_dict(torch.load(sincnet_path))
            MLP_net.load_state_dict(torch.load(mlp_path))
        except (OSError, RuntimeError) as exc:
            print('Could not load models: {}'.format(exc))

    if log:
        wandb.watch(models=SincNet_model)
        wandb.watch(models=MLP_net)

    optimizer_SincNet = optim.Adam(params=SincNet_model.parameters(), lr=lr)
    optimizer_MLP = optim.Adam(params=MLP_net.parameters(), lr=lr)

    for epoch in range(1, N_epochs + 1):
        start_time = time.time()
        train_losses_avg, train_accuracy_avg = train_windowed(
            epoch=epoch,
            train_loader=train_loader,
            SincNet_model=SincNet_model,
            MLP_model=MLP_net,
            optimizer_SincNet=optimizer_SincNet,
            optimizer_MLP=optimizer_MLP,
            device=device)
        duration = time.time() - start_time
        # BUGFIX: the format string was split across two source lines
        # (a syntax error); rejoined into a single literal.
        print("Done training epoch {} in {:.4f} \t Accuracy {:.2f} Loss {:.4f}".
              format(epoch, duration, train_accuracy_avg, train_losses_avg))
        test_losses_avg, test_accuracy_avg = test_windowed(
            test_loader=test_loader,
            SincNet_model=SincNet_model,
            MLP_model=MLP_net,
            epoch=epoch,
            device=device)
        if log:
            wandb.log({
                "Train Accuracy": train_accuracy_avg,
                "Train Loss": train_losses_avg,
                "Test Accuracy": test_accuracy_avg,
                "Test Loss": test_losses_avg
            })
        # Checkpoint every 5 epochs.
        if (epoch % 5) == 0:
            torch.save(SincNet_model.state_dict(), sincnet_path)
            torch.save(MLP_net.state_dict(), mlp_path)
            print("Model saved after {} epochs".format(epoch))
def main():
    """CLI entry point: parse args, build loaders, train SincNet + MLP.

    Network hyper-parameters come from the config file parsed by
    ``read_conf()``; run-level options (paths, batch sizes, seed, ...)
    come from argparse and are mirrored into ``wandb.config``. Trains
    with RMSprop and checkpoints every 20 epochs.
    """
    global args

    options = read_conf()
    wandb.init(project="sincnet_triplet")
    config = wandb.config

    def _str2bool(value):
        # BUGFIX: argparse ``type=bool`` is broken for flags -- any
        # non-empty string (including 'False') is truthy, so
        # ``--load-model False`` used to parse as True.
        return str(value).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser(
        description="SincNet Speaker Recognition from Raw Waveform")
    # BUGFIX: the help strings for the two batch-size flags were swapped.
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--train-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate')
    parser.add_argument('--seed',
                        type=int,
                        default=1234,
                        metavar='S',
                        help='random seed')
    parser.add_argument('--margin',
                        type=float,
                        default=8,
                        metavar='M',
                        help='margin for triplet loss (default: 0.2)')
    parser.add_argument('--name',
                        default='VGG_Spectogram_Triplet',
                        type=str,
                        help='name of network')
    parser.add_argument(
        '--train-set',
        default=
        '/home/lucvanwyk/Data/pyannote/Extracted_Speech/trimmed_sample_list_train.txt',
        type=str,
        help='path to train samples')
    parser.add_argument(
        '--test-set',
        default=
        '/home/lucvanwyk/Data/pyannote/Extracted_Speech/trimmed_sample_list_test.txt',
        type=str,
        help='path to test samples')
    parser.add_argument(
        '--valid-set',
        default=
        '/home/lucvanwyk/Data/pyannote/Extracted_Speech/trimmed_sample_list_valid.txt',
        type=str,
        help='path to validation samples')
    parser.add_argument(
        '--model-path-sincnet',
        default='/home/lucvanwyk/MetricEmbeddingNet/models/SincNet_Triplet',
        type=str,
        help='path to where sincnet models are saved/loaded')
    parser.add_argument(
        '--model-path-mlp',
        default='/home/lucvanwyk/MetricEmbeddingNet/models/MLP_Triplet',
        type=str,
        help='path to where mlp models are saved/loaded')
    parser.add_argument('--save-model',
                        type=_str2bool,
                        default=True,
                        help='save model?')
    parser.add_argument('--load-model',
                        type=_str2bool,
                        default=False,
                        help='load model?')
    parser.add_argument('--cfg',
                        type=str,
                        default='SincNet_options_Teapot.cfg',
                        help='configuration file')
    args = parser.parse_args()
    wandb.config.update(args)

    # Seed every RNG we use for reproducibility.
    torch.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cuda:0')
    kwargs = {'num_workers': 8, 'pin_memory': True}

    # BUGFIX: use the configured train set and batch size instead of the
    # hard-coded personal path and batch_size=16 (the intended call was
    # left commented out in the original), and actually pass the loader
    # kwargs, which were built but never used.
    train_loader = data.DataLoader(Window_Loader(filename=args.train_set,
                                                 windowed=True,
                                                 window_length=0.2,
                                                 overlap=0.01),
                                   batch_size=args.train_batch_size,
                                   shuffle=True,
                                   **kwargs)
    # BUGFIX: test_loader was commented out in the original but is passed
    # to test_windowed() below, which raised NameError at runtime.
    test_loader = data.DataLoader(Window_Loader(filename=args.test_set,
                                                windowed=True,
                                                window_length=0.2,
                                                overlap=0.01),
                                  batch_size=args.test_batch_size,
                                  shuffle=True,
                                  **kwargs)

    # [cnn] hyper-parameters (comma-separated values in the config file).
    cnn_N_filt = list(map(int, options.cnn_N_filt.split(',')))
    cnn_len_filt = list(map(int, options.cnn_len_filt.split(',')))
    cnn_max_pool_len = list(map(int, options.cnn_max_pool_len.split(',')))
    cnn_use_laynorm_inp = str_to_bool(options.cnn_use_laynorm_inp)
    cnn_use_batchnorm_inp = str_to_bool(options.cnn_use_batchnorm_inp)
    cnn_use_laynorm = list(map(str_to_bool, options.cnn_use_laynorm.split(',')))
    cnn_use_batchnorm = list(
        map(str_to_bool, options.cnn_use_batchnorm.split(',')))
    cnn_act = list(map(str, options.cnn_act.split(',')))
    cnn_drop = list(map(float, options.cnn_drop.split(',')))

    # [dnn] hyper-parameters.
    fc_lay = list(map(int, options.fc_lay.split(',')))
    fc_drop = list(map(float, options.fc_drop.split(',')))
    fc_use_laynorm_inp = str_to_bool(options.fc_use_laynorm_inp)
    fc_use_batchnorm_inp = str_to_bool(options.fc_use_batchnorm_inp)
    fc_use_batchnorm = list(
        map(str_to_bool, options.fc_use_batchnorm.split(',')))
    fc_use_laynorm = list(map(str_to_bool, options.fc_use_laynorm.split(',')))
    fc_act = list(map(str, options.fc_act.split(',')))

    SincNet_args = {
        # NOTE(review): 3200 samples at fs=16000 is 0.2 s, not the
        # "3 seconds" the original comment claimed -- confirm intent.
        'input_dim': 3200,
        'fs': 16000,
        'cnn_N_filt': cnn_N_filt,
        'cnn_len_filt': cnn_len_filt,
        'cnn_max_pool_len': cnn_max_pool_len,
        'cnn_use_laynorm_inp': cnn_use_laynorm_inp,
        'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp,
        'cnn_use_laynorm': cnn_use_laynorm,
        'cnn_use_batchnorm': cnn_use_batchnorm,
        'cnn_act': cnn_act,
        'cnn_drop': cnn_drop
    }
    SincNet_model = SincNet(SincNet_args)
    SincNet_model.to(device)

    DNN1_args = {
        'input_dim': SincNet_model.out_dim,
        'fc_lay': fc_lay,
        'fc_drop': fc_drop,
        'fc_use_batchnorm': fc_use_batchnorm,
        'fc_use_laynorm': fc_use_laynorm,
        'fc_use_laynorm_inp': fc_use_laynorm_inp,
        'fc_use_batchnorm_inp': fc_use_batchnorm_inp,
        'fc_act': fc_act
    }
    MLP_net = MLP(DNN1_args)
    MLP_net.to(device)

    print('----')
    print(SincNet_model.out_dim)
    wandb.watch(models=SincNet_model)
    wandb.watch(models=MLP_net)

    if args.load_model:
        # BUGFIX: catch the specific failures (missing file, state-dict
        # mismatch) instead of a bare ``except:``; report the cause.
        try:
            SincNet_model.load_state_dict(torch.load(args.model_path_sincnet))
            MLP_net.load_state_dict(torch.load(args.model_path_mlp))
        except (OSError, RuntimeError) as exc:
            print('Could not load models: {}'.format(exc))

    optimizer_SincNet = optim.RMSprop(params=SincNet_model.parameters(),
                                      lr=args.lr,
                                      momentum=0.5,
                                      alpha=0.8)
    optimizer_MLP = optim.RMSprop(params=MLP_net.parameters(),
                                  lr=args.lr,
                                  momentum=0.5,
                                  alpha=0.8)

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train_loss, train_acc = train_windowed(
            SincNet_model=SincNet_model,
            MLP_model=MLP_net,
            optimizer_SincNet=optimizer_SincNet,
            optimizer_MLP=optimizer_MLP,
            device=device,
            epoch=epoch,
            train_loader=train_loader)
        print('Finished training epoch {} loss {:.4f} accuracy {:.2f}'.format(
            epoch, train_loss, train_acc))
        test_loss, test_acc = test_windowed(SincNet_model=SincNet_model,
                                            MLP_model=MLP_net,
                                            epoch=epoch,
                                            device=device,
                                            test_loader=test_loader)
        wandb.log({
            'Train Loss': train_loss,
            'Train Accuracy': train_acc,
            'Test Loss': test_loss,
            'Test Accuracy': test_acc
        })
        print('Finished epoch {} in {:.2f}'.format(epoch,
                                                   (time.time() - start_time)))
        # Checkpoint every 20 epochs when saving is enabled.
        if args.save_model and (epoch % 20 == 0):
            torch.save(SincNet_model.state_dict(), args.model_path_sincnet)
            torch.save(MLP_net.state_dict(), args.model_path_mlp)
            print('Model saved after {} epochs'.format(epoch))