import torch.nn as nn from utils import Pairloader, SiameseNet, _tqdm as tqdm from torch.utils.data import DataLoader import os parser = argparse.ArgumentParser(description='Train SiameseNet') parser.add_argument('--save_location', '-sl', type=str, default='model/{}-epoch-{}.pth') parser.add_argument('--epochs', '-e', type=int, default=50) parser.add_argument('--save_every', '-se', type=int, default=5) parser.add_argument('--device', '-d', type=str, default=None) args = parser.parse_args() if not args.device: args.device = 'cuda' if torch.cuda.is_available() else 'cpu' model = SiameseNet(mode='train', device=args.device) datagen = DataLoader(Pairloader(split='train'), shuffle=True) bce_loss = nn.BCELoss() optimizer = optim.Adam(model.parameters(), lr=1e-4) for epoch in range(args.epochs): epoch_loss = 0.0 with tqdm(datagen) as t: for i, batch in enumerate(t): t.set_description('EPOCH: %i'%(epoch+1)) data1, data2, label = batch[0][0].to(device=args.device), batch[0][1].to(device=args.device), batch[1].to(device=args.device) optimizer.zero_grad()
import os parser = argparse.ArgumentParser(description='Train SiameseNet') parser.add_argument('--save_location', '-sl', type=str, default='model/{}-epoch-{}.pth') parser.add_argument('--epochs', '-e', type=int, default=50) parser.add_argument('--save_every', '-se', type=int, default=5) parser.add_argument('--device', '-d', type=str, default=None) args = parser.parse_args() if not args.device: args.device = 'cuda' if torch.cuda.is_available() else 'cpu' model = SiameseNet(mode='train', device=args.device) datagen = DataLoader(Pairloader(split='train'), shuffle=True) bce_loss = nn.BCELoss() optimizer = optim.Adam(model.parameters(), lr=1e-4) for epoch in range(args.epochs): model.train() epoch_loss = 0.0 with tqdm(datagen) as t: for i, batch in enumerate(t): t.set_description('EPOCH: %i' % (epoch + 1)) img1, img2, label = batch[0][0].to( device=args.device), batch[0][1].to(
refs = { 'up': preprocess(librosa.load(os.path.join(args.ref, 'up.wav'), sr=RATE)[0]), 'down': preprocess(librosa.load(os.path.join(args.ref, 'down.wav'), sr=RATE)[0]), 'sil': preprocess(librosa.load(os.path.join(args.ref, 'sil.wav'), sr=RATE)[0]), 'quit': preprocess(librosa.load(os.path.join(args.ref, 'quit.wav'), sr=RATE)[0]) } print('Loading model') model = SiameseNet(mode='inference', weights_path=args.model_location.format(args.epoch), refs_dict=refs, device=args.device) previous = np.zeros((CHUNK, 1)) audio = pyaudio.PyAudio() stream = audio.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK) print("Recording...") while True:
# Command-line options for the live test.
parser = argparse.ArgumentParser(description='Live test SiameseNet')
parser.add_argument('--model_location', '-l', type=str,
                    default='model/{}-epoch-{}.pth')
parser.add_argument('--epoch', '-e', type=int, default=None)
parser.add_argument('--device', '-d', type=str, default=None)
parser.add_argument('--ref', '-r', type=str, default='references/')
parser.add_argument('-t', '--target', type=str, default='data/test.wav')
args = parser.parse_args()

# Default to CUDA when available and no device was requested explicitly.
if not args.device:
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('Loading model')
model = SiameseNet().to(device=args.device)
# The checkpoint path is model_location formatted with ('model', epoch).
# NOTE(review): when --epoch is omitted this resolves to
# '...-epoch-None.pth'; consider making --epoch required.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(args.model_location.format('model', args.epoch),
               map_location=args.device))
# The model is only used for inference here: eval() disables dropout and makes
# batch-norm layers (if any) use running statistics instead of per-batch ones.
model.eval()

RATE = 16000  # sample rate passed to librosa for MFCC extraction


def preprocess(audio=None):
    """Convert a raw waveform into an MFCC tensor on ``args.device``.

    The waveform is trimmed of leading/trailing audio quieter than 7 dB below
    its peak. It is then cut to at most 4000 samples (0.25 s at ``RATE``) and
    zero-padded symmetrically to exactly 4000 samples, so every clip yields an
    MFCC matrix of the same size.

    Args:
        audio: 1-D waveform array, e.g. the first element returned by
            ``librosa.load``.

    Returns:
        A tensor holding the MFCC matrix with two leading singleton axes
        added (presumably batch and channel -- confirm against SiameseNet),
        moved to ``args.device``.

    Raises:
        ValueError: if ``audio`` is None.
    """
    if audio is None:
        raise ValueError('preprocess() requires an audio waveform')
    audio_trimmed = librosa.effects.trim(audio, top_db=7)[0]
    # `size` is keyword-only since librosa 0.10; the keyword form also works
    # on older releases.
    audio_center = librosa.util.pad_center(audio_trimmed[:4000], size=4000)
    audio_mfcc = librosa.feature.mfcc(y=audio_center, sr=RATE)
    audio_tensor = torch.tensor(audio_mfcc[None, None])
    return audio_tensor.to(device=args.device)
"""Validate a trained SiameseNet on the 'valid' split, printing one score per pair."""
import argparse

import torch
from torch.utils.data import DataLoader

from utils import Pairloader, SiameseNet

parser = argparse.ArgumentParser(description='Validate SiameseNet')
parser.add_argument('--model_location', '-l', type=str,
                    default='model/model-epoch-{}.pth')
parser.add_argument('--epoch', '-e', type=int, default=None)
parser.add_argument('--device', '-d', type=str, default=None)
args = parser.parse_args()

# Default to CUDA when available and no device was requested explicitly.
if not args.device:
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): when --epoch is omitted the weights path becomes
# 'model/model-epoch-None.pth'; consider making --epoch required.
model = SiameseNet(mode='validate',
                   weights_path=args.model_location.format(args.epoch),
                   device=args.device)
datagen = DataLoader(Pairloader(split='valid'))

# Validation never backpropagates, so skip building an autograd graph for
# every pair; this saves memory and time without changing the outputs.
with torch.no_grad():
    for batch in datagen:
        data1 = batch[0][0].to(device=args.device)
        data2 = batch[0][1].to(device=args.device)
        file_names = batch[1]
        output = model(data1, data2)
        # .item() requires a one-element output; DataLoader's default
        # batch_size is 1.
        print(file_names[0], " and ", file_names[1], ":- ", output.item())