def main():
    """Train a VGGVox speaker-embedding network with a triplet margin loss.

    Parses CLI settings, builds train/test DataLoaders over triplet time
    samples, optionally loads/resumes model weights, trains for ``--epochs``
    epochs with RMSprop, and finally saves the model's state dict.

    NOTE(review): a second ``main()`` defined later in this file shadows this
    one; only the later definition is live when the module runs.
    """
    # Train settings
    wandb.init(project="vgg_triplet")
    global args, best_acc

    def _str2bool(value):
        """Parse common true/false spellings for argparse boolean options.

        argparse's ``type=bool`` treats ANY non-empty string — including
        "False" — as True; this converter keeps the same CLI surface
        (``--save True`` / ``--save False``) but parses it correctly.
        """
        if isinstance(value, bool):
            return value
        if value.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if value.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser(
        description='VGG Triplet-Loss Speaker Embedding')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training')
    parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs for training')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed', type=int, default=2, metavar='S',
                        help='random seed')
    parser.add_argument(
        '--log-interval', type=int, default=20, metavar='N',
        help='how many batches to wait before logging training score')
    parser.add_argument('--margin', type=float, default=2, metavar='M',
                        help='margin for triplet loss (default: 0.2)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('--name', default='TripletNet_RMSprop', type=str,
                        help='name of experiment')
    parser.add_argument(
        '--base-path', type=str,
        default='/home/lucas/PycharmProjects/Papers_with_code/data/AMI/amicorpus_individual/Extracted_Speech',
        help='string to triplets')
    parser.add_argument('--ap-file', default='anchor_pairs.txt', type=str,
                        help='name of file with anchor-positive pairs')
    parser.add_argument('--s-file', default='trimmed_sample_list.txt',
                        type=str, help='name of sample list')
    parser.add_argument(
        '--save-path',
        default='/home/lucas/PycharmProjects/Papers_with_code/data/models/VGG_Triplet',
        type=str, help='path to save models to')
    # BUGFIX: these two used type=bool, so "--save False" / "--load False"
    # evaluated to True (bool("False") is True). _str2bool fixes the parse
    # while keeping defaults and the CLI surface unchanged.
    parser.add_argument('--save', type=_str2bool, default=True,
                        help='save model?')
    parser.add_argument('--load', type=_str2bool, default=False,
                        help='load model from latest checkpoint')
    args = parser.parse_args()

    # Record run metadata in Weights & Biases.
    wandb.run.name = args.name
    wandb.run.save()
    wandb.config.update(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.no_cuda)
    print(torch.cuda.is_available())
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")
    torch.manual_seed(args.seed)

    # pin_memory/workers only pay off when transferring to a CUDA device.
    kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
    train_time_loader = torch.utils.data.DataLoader(
        Triplet_Time_Loader(path=os.path.join(args.base_path, args.s_file),
                            train=True),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_time_loader = torch.utils.data.DataLoader(
        Triplet_Time_Loader(path=os.path.join(args.base_path, args.s_file),
                            train=False),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    class Net(nn.Module):
        """Small 4-conv/2-fc embedding CNN producing a 256-d output.

        NOTE(review): this class is currently dead code — ``VGGVox()`` is
        instantiated below instead. Kept for reference/experimentation.
        """

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 16, kernel_size=7)
            self.conv2 = nn.Conv2d(16, 16, kernel_size=7)
            self.bn_1 = nn.BatchNorm2d(16)
            self.conv3 = nn.Conv2d(16, 32, kernel_size=7)
            self.conv4 = nn.Conv2d(32, 32, kernel_size=7)
            self.bn_2 = nn.BatchNorm2d(32)
            self.conv2_drop = nn.Dropout2d(p=0.2)
            self.fc1 = nn.Linear(448, 256)
            self.fc2 = nn.Linear(256, 256)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(F.max_pool2d(self.bn_1(self.conv2(x)), 7))
            x = F.relu(self.conv3(x))
            x = F.relu(
                F.max_pool2d(self.conv2_drop(self.bn_2(self.conv4(x))), 7))
            # Flatten all feature maps into a single vector per sample.
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            return self.fc2(x)

    model = VGGVox()
    if args.cuda:
        model.to(device)
    if args.load:
        model.load_state_dict(torch.load(args.save_path))
        print("Model loaded from state dict")
    wandb.watch(model)

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # NOTE(review): start_epoch/best_prec1 are restored here but the
            # training loop below always starts at epoch 1 and never reads
            # best_prec1 — resuming does not actually skip completed epochs.
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    criterion = nn.TripletMarginLoss(margin=args.margin, p=2)
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, alpha=0.8,
                              momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train_batch(train_time_loader, model, optimizer, epoch)
        test_batch(test_time_loader, model, epoch)
        duration = time.time() - start_time
        print("Done training epoch {} in {:.4f}".format(epoch, duration))

    if args.save:
        torch.save(model.state_dict(), args.save_path)
        print("Model Saved")
def main():
    """Train VGGVox on (mel-)spectrograms for speaker verification.

    Parses CLI settings, builds train/test/validation DataLoaders, trains
    with RMSprop for ``--epochs`` epochs, logs metrics to Weights & Biases,
    and periodically checkpoints the model plus a validation plot.

    NOTE(review): this redefines ``main()`` from earlier in the file; this
    later definition is the one that actually runs.
    """
    global args
    wandb.init(project="vgg_triplet")
    config = wandb.config

    def _str2bool(value):
        """Parse common true/false spellings for argparse boolean options.

        argparse's ``type=bool`` treats ANY non-empty string — including
        "False" — as True; this converter keeps the same CLI surface but
        parses it correctly.
        """
        if isinstance(value, bool):
            return value
        if value.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if value.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser(
        description="VGGVox CNN with Spectrograms for Speaker Verification")
    # BUGFIX: the help strings for these two options were swapped
    # (test said "training", train said "testing").
    parser.add_argument('--test-batch-size', type=int, default=40,
                        metavar='N', help='input batch size for testing')
    parser.add_argument('--train-batch-size', type=int, default=40,
                        metavar='N', help='input batch size for training')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='random seed')
    parser.add_argument('--margin', type=float, default=8, metavar='M',
                        help='margin for triplet loss (default: 0.2)')
    parser.add_argument('--name', default='VGG_Spectogram_Triplet', type=str,
                        help='name of network')
    parser.add_argument(
        '--train-set',
        default='/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_train.txt',
        type=str, help='path to train samples')
    parser.add_argument(
        '--test-set',
        default='/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_test.txt',
        type=str, help='path to test samples')
    parser.add_argument(
        '--valid-set',
        default='/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_valid.txt',
        type=str, help='path to validation samples')
    parser.add_argument(
        '--model-path',
        default='/home/lucas/PycharmProjects/MetricEmbeddingNet/models/VGG_Spectrogram_Triplet.pt',
        type=str, help='path to where models are saved/loaded')
    # BUGFIX: these three used type=bool, so e.g. "--load-model False"
    # evaluated to True. _str2bool keeps the same defaults and CLI surface.
    parser.add_argument('--save-model', type=_str2bool, default=True,
                        help='save model?')
    parser.add_argument('--load-model', type=_str2bool, default=False,
                        help='load model?')
    parser.add_argument('--melspectrogram', type=_str2bool, default=False,
                        help='use melspectrogram?')
    args = parser.parse_args()
    wandb.config.update(args)

    # Seed every RNG in play so runs are reproducible.
    torch.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cuda:0')
    kwargs = {'num_workers': 6, 'pin_memory': True}
    # BUGFIX: mel was hard-coded to False, silently ignoring the
    # --melspectrogram flag; its default (False) preserves old behavior.
    train_loader = data.DataLoader(
        Spectrogram_Loader(filename=args.train_set, mel=args.melspectrogram),
        batch_size=config.train_batch_size, shuffle=True, **kwargs)
    test_loader = data.DataLoader(
        Spectrogram_Loader(filename=args.test_set, mel=args.melspectrogram),
        batch_size=config.test_batch_size, shuffle=True, **kwargs)
    valid_loader = data.DataLoader(
        Spectrogram_Loader(filename=args.valid_set, mel=args.melspectrogram),
        batch_size=config.test_batch_size, shuffle=True, **kwargs)

    model = VGGVox()
    model.to(device)
    if args.load_model:
        try:
            model.load_state_dict(torch.load(args.model_path))
        # BUGFIX: was a bare "except:" that swallowed every error (including
        # KeyboardInterrupt). Catch only file-access and state-dict-mismatch
        # failures; anything else should surface.
        except (OSError, RuntimeError):
            print("Could not load model {} not found".format(args.model_path))

    # NOTE(review): lr is hard-coded to 0.001 here, so the --lr option
    # (default 0.0001) is ignored — confirm whether this is intentional.
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.8,
                              momentum=0.5)
    wandb.watch(model)

    for epoch in range(1, config.epochs + 1):
        start_time = time.time()
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      optimizer=optimizer,
                                      device=device,
                                      epoch=epoch)
        test_loss, test_acc = test(data_loader=test_loader,
                                   model=model,
                                   device=device,
                                   epoch=epoch)
        print('Finished epoch {} in {:.2f} '.format(
            epoch, (time.time() - start_time)))
        wandb.log({
            'Train Loss': train_loss,
            'Train Accuracy': train_acc,
            'Test Loss': test_loss,
            'Test Accuracy': test_acc
        })
        # Checkpoint (and plot the validation embedding) every 20 epochs.
        if config.save_model and (epoch % 20 == 0):
            torch.save(model.state_dict(), config.model_path)
            print("Model saved after {} epochs".format(epoch))
            plot_validation_set(valid_loader=valid_loader, model=model,
                                epoch=epoch)
            print("Validation plot saved")