Beispiel #1
0
def run(demo_path, models):
    """Evaluate checkpointed models on the demo clip, or on the validation set.

    Args:
        demo_path: audio path handed to ``demo.wavLoader``; only used while
            ``demotest`` is True.
        models: dict mapping an architecture name to a weight. Recognised names
            are ``'LeNet'`` and prefixes ``'VGG'``, ``'ResNet50'``,
            ``'ResNet18'``, ``'CRNN'``, ``'ResNet101'`` and ``'resnext'``;
            anything else falls back to LeNet. Each model is built, loaded from
            ``./checkpoint/<name>.pth`` when that file exists, and passed on to
            ``demo_test``/``test`` as part of a ``{model: weight}`` dict.

    Returns:
        Whatever ``demo_test`` returns in demo mode, otherwise None.
    """
    # vis = visdom.Visdom(use_incoming_socket=False)
    if torch.cuda.is_available():
        torch.cuda.set_device(1)
        idx = torch.cuda.current_device()
        print("Current GPU:" + str(idx))
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # Decided once, up front, so `cuda` is defined even when `models` is empty.
    cuda = device.type == "cuda"

    test_path = r"/home/panzh/DataSet/Urbandataset/valid"

    demotest = True
    # parameter

    test_batch_size = 1
    loaderType = "logmeldelta"  # logmeldelta, logmel, stft
    # sound setting
    window_size = 0.02  # 0.02
    window_stride = 0.01  # 0.01
    window_type = 'hamming'
    normalize = True
    allLabels = {'baby_cry': 0, 'car_engine': 1, 'crowd': 2, 'dog_bark': 3, 'gun_shot': 4,
                 'multispeaker': 5, 'scream': 6, 'siren': 7, 'speaking': 8, 'stLaughter': 9,
                 # NOTE(review): 'telephone' previously had no index (a SyntaxError);
                 # 10 is assumed as the next free one -- confirm it matches the label
                 # set the checkpoints were trained on.
                 'telephone': 10}

    # loading data
    if demotest:
        # TODO: store loader_type as a list to support several preprocessing methods
        demo_dataset = demo.wavLoader(test_path, demo_path, allLabels=allLabels,
                                      window_size=window_size, window_stride=window_stride,
                                      window_type=window_type, normalize=normalize,
                                      loader_type=loaderType)
        demo_loader = torch.utils.data.DataLoader(demo_dataset, batch_size=test_batch_size,
                                                  shuffle=None, num_workers=1,
                                                  pin_memory=True, sampler=None)
    else:
        test_dataset = wavLoader(test_path, window_size=window_size, window_stride=window_stride,
                                 window_type=window_type, normalize=normalize,
                                 loader_type=loaderType)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,
                                                  shuffle=None, num_workers=4,
                                                  pin_memory=True, sampler=None)

    cudnn.benchmark = True  # global switch; only needs setting once
    modellist = {}
    for arc, weight in models.items():
        if arc == 'LeNet':
            model = LeNet()
            print("Using LeNet")
        elif arc.startswith('VGG'):
            model = VGG(arc)
            print("Using VGG")
        elif arc.startswith('ResNet50'):
            model = resnet50()
            print("Using ResNet50")
        elif arc.startswith("ResNet18"):
            model = resnet18()
            print("Using resnet18")
        elif arc.startswith("CRNN"):
            # model = CRNN(nc=1, nh=96)
            model = CRNN_GRU()
        elif arc.startswith("ResNet101"):
            model = resnet101()
            print("Using resnet101")
        elif arc.startswith("resnext"):
            model = resnext101_32x4d_(pretrained=None)
            print("resnext101_32x4d_")
        else:
            model = LeNet()

        # build model
        if cuda:
            model = torch.nn.DataParallel(model.cuda(), device_ids=[idx])
            print("Using cuda for model...")

        ckpt_path = os.path.join('./checkpoint', str(arc) + '.pth')
        if os.path.isfile(ckpt_path):
            # map_location lets a checkpoint written on a GPU load on a CPU-only host.
            state = torch.load(ckpt_path, map_location=device)
            print('load pre-trained model of ' + str(arc) + '\n')
            state_dict = state['state_dict']
            if not cuda:
                # Checkpoints saved from a DataParallel-wrapped model prefix every key
                # with 'module.'; strip it so the unwrapped CPU model can load them.
                state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
                              for k, v in state_dict.items()}
            model.load_state_dict(state_dict)
        modellist[model] = weight

    print('\nStart testing...')
    pre_label = None
    if demotest:
        pre_label = demo_test(demo_loader, modellist, cuda, mode='Test loss',
                              class2index=demo_dataset.getClass2Index())
    else:
        test(test_loader, modellist, cuda, mode='Test loss',
             class2index=test_dataset.getClass2Index())
    print('Finished!!')
    return pre_label
Beispiel #2
0
from custom_wav_loader import wavLoader
import torch

# Sanity check: iterate once over dataset/test and print each batch's shape.
dataset = wavLoader('dataset/test')

test_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=100,
                                          shuffle=None,
                                          num_workers=4,
                                          pin_memory=True,
                                          sampler=None)

# `inputs` rather than `input`, so the builtin is not shadowed.
for k, (inputs, label) in enumerate(test_loader):
    print(inputs.size(), len(label))

## need to check wav file is correctly imported

import visdom
import numpy as np
import matplotlib.pyplot as plt
vis = visdom.Visdom(use_incoming_socket=False)

import librosa
y, sr = librosa.load('test.wav')
# melspectrogram returns a *power* spectrogram (power=2.0 by default), so it is
# converted with power_to_db; amplitude_to_db would square it again and double
# the dB scale. The signal must be passed as keyword `y=`: positional audio
# arguments were removed in librosa 0.10.
melgram = librosa.power_to_db(
    librosa.feature.melspectrogram(y=y, sr=sr, n_mels=120, n_fft=1024))
# print(melgram[0])
vis.heatmap(melgram, opts=dict(title='test.wav'))
# vis.heatmap(melgram[0:20],opts=dict(title='test.wav'))
Beispiel #3
0
def run():
    """Train the configured architecture with early stopping, or run the demo clip.

    All settings (dataset paths, optimizer, learning rate, architecture,
    spectrogram options) are hard-coded below. With ``demotest = False`` the
    model is trained for up to ``epochs`` epochs and stops early after
    ``patience`` consecutive epochs whose validation loss exceeds the best so
    far. Each improvement is saved to ``./checkpoint/<arc>.pth``. If that file
    already exists, training resumes from it.
    """
    # vis = visdom.Visdom(use_incoming_socket=False)
    if torch.cuda.is_available():
        torch.cuda.set_device(1)
        idx = torch.cuda.current_device()
        print("Current GPU:" + str(idx))
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    cuda = device.type == "cuda"

    train_path = r"/home/panzh/DataSet/Urbandataset/train"
    valid_path = r"/home/panzh/DataSet/Urbandataset/valid"
    test_path = r"/home/panzh/DataSet/Urbandataset/test"
    demo_path = r"/home/panzh/Downloads/demoAudio/test2.wav"
    demotest = False
    # parameter
    optimizer_name = 'adadelta'  # adadelta adam SGD
    lr = 0.08  # to do : adaptive lr 0.001
    epochs = 300
    epoch = 1
    momentum = 0.9  # for SGD

    iteration = 0  # consecutive epochs without validation improvement
    patience = int(0.25 * epochs)
    log_interval = 50

    seed = 1234  # random seed
    # The seed was previously defined but never applied.
    torch.manual_seed(seed)
    batch_size = 32  # 100
    test_batch_size = 16
    arc = 'ResNet101'  # LeNet, VGG11, VGG13, VGG16, VGG19' ResNet CRNN resnext
    loaderType = "logmeldelta"  # logmeldelta, logmel, stft
    # sound setting
    window_size = 0.02  # 0.02
    window_stride = 0.01  # 0.01
    window_type = 'hamming'
    normalize = True
    # Shared by every wavLoader below.
    spect_opts = dict(window_size=window_size,
                      window_stride=window_stride,
                      window_type=window_type,
                      normalize=normalize,
                      loader_type=loaderType)

    # Per-class loss weights passed to train(); placed on the same device as the
    # model (a bare .cuda() crashed on CPU-only hosts).
    weight = torch.tensor(
        (0.5, 0.25, 0.25, 0.25, 1, 0.6, 1, 1, 0.16, 0.4)).to(device)
    # loading data
    if demotest:
        demo_dataset = demo.wavLoader(test_path, demo_path, **spect_opts)
        demo_loader = torch.utils.data.DataLoader(demo_dataset,
                                                  batch_size=test_batch_size,
                                                  shuffle=None,
                                                  num_workers=1,
                                                  pin_memory=True,
                                                  sampler=None)
    else:
        train_dataset = wavLoader(train_path, **spect_opts)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=4,
                                                   pin_memory=True,
                                                   sampler=None)
        valid_dataset = wavLoader(valid_path, **spect_opts)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=None,
                                                   num_workers=4,
                                                   pin_memory=True,
                                                   sampler=None)
        test_dataset = wavLoader(test_path, **spect_opts)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=test_batch_size,
                                                  shuffle=None,
                                                  num_workers=4,
                                                  pin_memory=True,
                                                  sampler=None)

    # build model
    if arc == 'LeNet':
        model = LeNet()
        print("LeNet")
    elif arc.startswith('VGG'):
        model = VGG(arc)
        print("VGG")
    elif arc.startswith('ResNet50'):
        model = resnet50()
        print("ResNet50")
    elif arc.startswith("ResNet18"):
        model = resnet18()
        print("resnet18")
    elif arc.startswith("CRNN"):
        # model = CRNN(nc=1, nh=96)
        model = CRNN_GRU()
    elif arc.startswith("ResNet101"):
        model = resnet101()
        print("resnet101")
    elif arc.startswith("resnext"):
        model = resnext101_32x4d_(pretrained=None)
        print("resnext101_32x4d_")
    else:
        model = LeNet()

    if cuda:
        model = torch.nn.DataParallel(model.cuda(), device_ids=[idx])
        print("Using cuda for model...")

    # define optimizer; unrecognised names fall back to SGD with momentum
    opt = optimizer_name.lower()
    if opt == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif opt == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
    else:  # 'sgd' or unknown
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    cudnn.benchmark = True
    best_valid_loss = np.inf

    ckpt_path = os.path.join('./checkpoint', str(arc) + '.pth')
    if os.path.isfile(ckpt_path):
        # map_location lets a GPU-written checkpoint load on a CPU-only host.
        state = torch.load(ckpt_path, map_location=device)
        print('load pre-trained model of ' + str(arc) + '\n')
        #print(state.keys())
        # The 'acc' key holds the best *validation loss* (see the save below).
        best_valid_loss = state['acc']
        # NOTE(review): the saved epoch was already completed, so resuming here
        # repeats it; state['epoch'] + 1 may be what is intended.
        epoch = state['epoch']
        if str(arc) == "resnext1":  # was `is`, which compares identity, not text
            model.load_state_dict(state)
        else:
            model.load_state_dict(state['state_dict'])
        set_lasttime_lr(optimizer, state['lr'])
    # visdom
    # loss_graph = vis.line(Y=np.column_stack([10, 10, 10]), X=np.column_stack([0, 0, 0]),
    #                      opts=dict(title='loss', legend=['Train loss', 'Valid loss', 'Test loss'], showlegend=True,
    #                                xlabel='epoch'))

    # training with early stopping

    print('\nStart training...')
    if demotest:
        demo_test(demo_loader,
                  model,
                  cuda,
                  mode='Test loss',
                  class2index=demo_dataset.getClass2Index())
    else:
        while (epoch < epochs + 1) and (iteration < patience):
            # adjust the learning rate
            lr = adjust_learning_rate(optimizer, epoch)
            # train for one epoch
            train(train_loader, model, optimizer, epoch, cuda, log_interval,
                  weight)
            print('train Finished!!')
            # compute train / validation / test losses
            train_loss = test(train_loader,
                              model,
                              cuda,
                              mode='Train loss',
                              class2index=train_dataset.getClass2Index())
            valid_loss = test(valid_loader,
                              model,
                              cuda,
                              mode='Valid loss',
                              class2index=valid_dataset.getClass2Index())
            test_loss = test(test_loader,
                             model,
                             cuda,
                             mode='Test loss',
                             class2index=test_dataset.getClass2Index())

            # save only when the validation loss does not get worse
            if valid_loss > best_valid_loss:
                iteration += 1
                print('\nLoss was not improved, iteration {0}\n'.format(
                    str(iteration)))
            else:
                print('\nSaving model of ' + str(arc) + '\n')
                iteration = 0
                best_valid_loss = valid_loss
                state = {
                    'net': arc,
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'acc': valid_loss,
                    'lr': lr
                }
                os.makedirs('checkpoint', exist_ok=True)
                torch.save(state, ckpt_path)
            epoch += 1
            # vis.line(Y=np.column_stack([train_loss, valid_loss, test_loss]), X=np.column_stack([epoch, epoch, epoch]),
            #         win=loss_graph, update='append',
            #        opts=dict(legend=['Train loss', 'Valid loss', 'Test loss'], showlegend=True))
    print('Finished!!')
Beispiel #4
0
# --- hyper-parameters -------------------------------------------------------
seed = 1234  # random seed
batch_size = 20  # 100
test_batch_size = 10
arc = 'resnet18'

# --- spectrogram settings (alternative values noted in trailing comments) ---
window_size, window_stride = 0.01, 0.01  # 0.02 / 0.01
window_type = 'hamming'
normalize = True

# --- datasets and loaders ---------------------------------------------------
# Training split: shuffled every epoch.
train_dataset = wavLoader(
    train_path,
    window_size=window_size,
    window_stride=window_stride,
    window_type=window_type,
    normalize=normalize,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    sampler=None,
)
# Validation split: same spectrogram settings as training.
valid_dataset = wavLoader(
    valid_path,
    window_size=window_size,
    window_stride=window_stride,
    window_type=window_type,
    normalize=normalize,
)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=batch_size,