Example #1
    def test_scale(self):

        audio_orig = self.sig.clone()
        result = transforms.Scale()(audio_orig)
        self.assertTrue(result.min() >= -1. and result.max() <= 1.)

        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
        result = transforms.Scale(factor=maxminmax)(audio_orig)
        self.assertTrue((result.min() == -1. or result.max() == 1.)
                        and result.min() >= -1. and result.max() <= 1.)

        repr_test = transforms.Scale()
        self.assertTrue(repr_test.__repr__())
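
Note: in the pre-0.3 torchaudio API these examples use, Scale simply divides the signal by a fixed factor (by default the int32 full-scale value) so raw PCM lands in [-1, 1]. A minimal sketch of that behavior, assuming the old default factor of 2**31:

import torch

def scale(tensor: torch.Tensor, factor: float = 2 ** 31) -> torch.Tensor:
    """Map integer-range PCM samples into [-1.0, 1.0] by dividing by a fixed factor."""
    if not tensor.is_floating_point():
        tensor = tensor.float()
    return tensor / factor

Passing factor=maxminmax, as the test does, makes the scaled signal touch exactly -1. or 1.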
Example #2
    def test_scale(self):

        audio_orig = self.sig.clone()
        result = transforms.Scale()(audio_orig)
        self.assertTrue(
            result.min() >= -1. and result.max() <= 1.,
            "min: {}, max: {}".format(result.min(), result.max()))

        maxminmax = np.abs([audio_orig.min(),
                            audio_orig.max()]).max().astype(float)
        result = transforms.Scale(factor=maxminmax)(audio_orig)
        self.assertTrue(
            (result.min() == -1. or result.max() == 1.) and result.min() >= -1.
            and result.max() <= 1.,
            "min: {}, max: {}".format(result.min(), result.max()))
Example #3
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     spectrogram_torch = transforms.MEL2()(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.max() <= 0.)
Example #4
    def test_mel(self):

        audio = self.sig.clone()
        audio = transforms.Scale()(audio)
        self.assertTrue(len(audio.size()) == 2)
        result = transforms.MEL()(audio)
        self.assertTrue(len(result.size()) == 3)
        result = transforms.BLC2CBL()(result)
        self.assertTrue(len(result.size()) == 3)
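
The layout helpers in this test are plain tensor permutes. A minimal sketch of both, assuming the old (length, channels) and (bands, length, channels) conventions these transforms documented:

import torch

def lc2cl(tensor: torch.Tensor) -> torch.Tensor:
    """(length, channels) -> (channels, length)."""
    return tensor.transpose(0, 1).contiguous()

def blc2cbl(tensor: torch.Tensor) -> torch.Tensor:
    """(bands, length, channels) -> (channels, bands, length)."""
    return tensor.permute(2, 0, 1).contiguous()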
Example #5
    def get_dataloader(self):
        vx = VOXFORGE(args.data_path,
                      langs=args.languages,
                      label_type="lang",
                      use_cache=args.use_cache,
                      use_precompute=args.use_precompute)
        if self.model_name == "resnet34_conv" or self.model_name == "resnet101_conv":
            T = tat.Compose([
                #tat.PadTrim(self.max_len),
                tat.MEL(n_mels=224),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        elif self.model_name == "resnet34_mfcc":
            sr = 16000
            ws = 800
            hs = ws // 2
            n_fft = 512  # 256
            n_filterbanks = 26
            n_coefficients = 12
            low_mel_freq = 0
            high_freq_mel = (2595 * math.log10(1 + (sr / 2) / 700))
            mel_pts = torch.linspace(low_mel_freq, high_freq_mel,
                                     n_filterbanks + 2)  # sr = 16000
            hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
            bins = torch.floor((n_fft + 1) * hz_pts / sr)
            td = {
                "RfftPow": spl_transforms.RfftPow(n_fft),
                "FilterBanks": spl_transforms.FilterBanks(n_filterbanks, bins),
                "MFCC": spl_transforms.MFCC(n_filterbanks, n_coefficients),
            }

            T = tat.Compose([
                tat.Scale(),
                #tat.PadTrim(self.max_len, fill_value=1e-8),
                spl_transforms.Preemphasis(),
                spl_transforms.Sig2Features(ws, hs, td),
                spl_transforms.DummyDim(),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        vx.transform = T
        vx.target_transform = TT
        if args.use_precompute:
            vx.load_precompute(args.model_name)
        dl = data.DataLoader(vx,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=True)
        return vx, dl
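
The MFCC branch above derives its filterbank bin edges from the standard HTK mel formula m = 2595 * log10(1 + f / 700). A standalone NumPy restatement of the same arithmetic (for sr = 16000 the upper edge comes out to about 2840.02 mel):

import numpy as np

sr, n_fft, n_filterbanks = 16000, 512, 26
high_freq_mel = 2595 * np.log10(1 + (sr / 2) / 700)      # ~2840.02 mel at sr=16000
mel_pts = np.linspace(0, high_freq_mel, n_filterbanks + 2)
hz_pts = np.floor(700 * (10 ** (mel_pts / 2595) - 1))    # mel edges back to Hz
bins = np.floor((n_fft + 1) * hz_pts / sr)               # FFT bin index per edge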
Example #6
def load_dataset(dataset='VCTK', train_subset=1.0, person_filter=None):

    transfs = transforms.Compose([
        transforms.Scale(),
        prepro.DB_Spec(n_fft=400, hop_t=0.010, win_t=0.025)
    ])

    if dataset == 'VCTK':
        person_filter = [
            'p249', 'p239', 'p276', 'p283', 'p243', 'p254', 'p258', 'p271'
        ]
        train_dataset = vctk_custom_dataset.VCTK('../datasets/VCTK-Corpus/',
                                                 preprocessed=True,
                                                 person_filter=person_filter,
                                                 filter_mode='exclude')
        test_dataset = vctk_custom_dataset.VCTK('../datasets/VCTK-Corpus/',
                                                preprocessed=True,
                                                person_filter=person_filter,
                                                filter_mode='include')
    elif dataset == 'LibriSpeech':
        train_dataset = librispeech_custom_dataset.LibriSpeech(
            '../datasets/LibriSpeech/',
            preprocessed=True,
            split='train',
            person_filter=person_filter,
            filter_mode='include')
        test_dataset = librispeech_custom_dataset.LibriSpeech(
            '../datasets/LibriSpeech/',
            preprocessed=True,
            split='test',
            person_filter=person_filter,
            filter_mode='include')

    indices = list(range(len(train_dataset)))
    split = int(np.floor(len(train_dataset) * train_subset))

    train_sampler = sampler.SubsetRandomSampler(indices[:split])
    test_sampler = sampler.RandomSampler(test_dataset)

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.use_cuda else {}
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               drop_last=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              sampler=test_sampler,
                                              drop_last=False,
                                              **kwargs)

    return train_loader, test_loader, train_dataset, test_dataset
Example #7
    def test_mel2(self):
        top_db = 80.
        s2db = transforms.SpectrogramToDB("power", top_db)

        audio_orig = self.sig.clone()  # (16000, 1)
        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(
            spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {
            "window": torch.hamming_window,
            "pad": 10,
            "ws": 500,
            "hop": 125,
            "n_fft": 800,
            "n_mels": 50
        }
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(
            spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
        spectrogram_stereo = s2db(mel_transform(x_stereo))
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(
            spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100,
                                                  sr=16000,
                                                  f_max=None,
                                                  f_min=0.,
                                                  n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
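
SpectrogramToDB with a top_db converts power to decibels and clips everything more than top_db below the peak, which is exactly what the ge(max - top_db) assertions check. A minimal sketch of that conversion, not the exact torchaudio implementation:

import torch

def power_to_db(spec: torch.Tensor, top_db: float = 80.0, eps: float = 1e-10) -> torch.Tensor:
    """10 * log10(power), clipped to within top_db of the peak value."""
    db = 10.0 * torch.log10(spec.clamp(min=eps))
    return db.clamp(min=(db.max() - top_db).item())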
Example #8
    def test_mel(self):

        audio = self.sig.clone()
        audio = transforms.Scale()(audio)
        self.assertTrue(audio.dim() == 2)
        result = transforms.MEL()(audio)
        self.assertTrue(result.dim() == 3)
        result = transforms.BLC2CBL()(result)
        self.assertTrue(result.dim() == 3)

        repr_test = transforms.MEL()
        repr_test.__repr__()
        repr_test = transforms.BLC2CBL()
        repr_test.__repr__()
Example #9
    def test_compose(self):

        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        maxminmax = np.abs([audio_orig.min(),
                            audio_orig.max()]).max().astype(float)

        tset = (transforms.Scale(factor=maxminmax),
                transforms.PadTrim(max_len=length_new))
        result = transforms.Compose(tset)(audio_orig)

        self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)

        self.assertTrue(result.size(0) == length_new)
Example #10
    def test_compose(self):

        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()

        tset = (transforms.Scale(factor=maxminmax),
                transforms.PadTrim(max_len=length_new, channels_first=False))
        result = transforms.Compose(tset)(audio_orig)

        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)

        self.assertTrue(result.size(0) == length_new)

        repr_test = transforms.Compose(tset)
        self.assertTrue(repr_test.__repr__())
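
PadTrim brings the time axis to exactly max_len, which is why the composed result satisfies size(0) == length_new. A minimal sketch for the (length, channels) layout used here; the real transform also handles channels_first and custom fill values:

import torch

def pad_trim(tensor: torch.Tensor, max_len: int, fill_value: float = 0.0) -> torch.Tensor:
    """Pad with fill_value or truncate along dim 0 (time) to exactly max_len samples."""
    n = tensor.size(0)
    if n < max_len:
        pad = tensor.new_full((max_len - n, tensor.size(1)), fill_value)
        return torch.cat([tensor, pad], dim=0)
    return tensor[:max_len]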
Example #11
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     mel_transform = transforms.MEL2()
     # check defaults
     spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.le(0.).all())
     self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
     # check correctness of filterbank conversion matrix
     self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
     self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
     # check options
     mel_transform2 = transforms.MEL2(window=torch.hamming_window,
                                      pad=10,
                                      ws=500,
                                      hop=125,
                                      n_fft=800,
                                      n_mels=50)
     spectrogram2_torch = mel_transform2(audio_scaled)  # (1, 506, 50)
     self.assertTrue(spectrogram2_torch.dim() == 3)
     self.assertTrue(spectrogram2_torch.le(0.).all())
      self.assertTrue(spectrogram2_torch.ge(mel_transform2.top_db).all())
     self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
     self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
     self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
     # check on multi-channel audio
     x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
     spectrogram_stereo = mel_transform(x_stereo)
     self.assertTrue(spectrogram_stereo.dim() == 3)
     self.assertTrue(spectrogram_stereo.size(0) == 2)
     self.assertTrue(spectrogram_stereo.le(0.).all())
     self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
     # check filterbank matrix creation
     fb_matrix_transform = transforms.F2M(n_mels=100,
                                          sr=16000,
                                          f_max=None,
                                          f_min=0.,
                                          n_stft=400)
     self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
     self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
     self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
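
The filterbank assertions (fb.sum(1) between 0 and 1, shape (n_stft, n_mels)) hold for a standard triangular mel filterbank. A sketch of how such a matrix can be built; this is the conventional recipe, not necessarily F2M's exact code:

import math
import torch

def mel_filterbank(n_stft: int = 400, n_mels: int = 100, sr: int = 16000,
                   f_min: float = 0.0, f_max: float = None) -> torch.Tensor:
    """Triangular mel filterbank of shape (n_stft, n_mels); column m is filter m."""
    f_max = f_max if f_max is not None else sr / 2.0
    m_min = 2595.0 * math.log10(1.0 + f_min / 700.0)
    m_max = 2595.0 * math.log10(1.0 + f_max / 700.0)
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = 700.0 * (10.0 ** (m_pts / 2595.0) - 1.0)   # filter edge frequencies (Hz)
    freqs = torch.linspace(0.0, sr / 2.0, n_stft)      # STFT bin frequencies
    fb = torch.zeros(n_stft, n_mels)
    for m in range(n_mels):
        left, center, right = f_pts[m], f_pts[m + 1], f_pts[m + 2]
        up = (freqs - left) / (center - left)          # rising edge of triangle m
        down = (right - freqs) / (right - center)      # falling edge
        fb[:, m] = torch.clamp(torch.min(up, down), min=0.0)
    return fb

Adjacent triangles overlap so that, at any frequency bin, the rising edge of one filter and the falling edge of its neighbor sum to at most 1, which is what the le(1.) checks rely on.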
Example #12
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
     spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.le(0.).all())
     self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
     # load stereo file
     x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
     spectrogram_stereo = mel_transform(x_stereo)
     self.assertTrue(spectrogram_stereo.dim() == 3)
     self.assertTrue(spectrogram_stereo.size(0) == 2)
     self.assertTrue(spectrogram_stereo.le(0.).all())
     self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
Example #13
    def test_mfcc(self):
        audio_orig = self.sig.clone()
        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[1] == 321)
        # check melkwargs are passed through
        melkwargs = {'ws': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)
        self.assertTrue(torch_mfcc2.shape[1] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

        norm_check = torch_mfcc.clone()
        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
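
The norm check at the end relies on DCT-II scaling: with norm='ortho', coefficient 0 carries an extra factor of sqrt(1/(4*n_mels)) and the rest sqrt(1/(2*n_mels)) relative to norm=None, so multiplying by 2*sqrt(n_mels) and 2*sqrt(n_mels/2) recovers the unnormalized values. Assuming a scipy-style DCT-II over the mel bins, this can be verified directly:

import math
import numpy as np
from scipy.fftpack import dct

n_mels = 128
x = np.random.randn(n_mels)
c_ortho = dct(x, type=2, norm='ortho')
c_none = dct(x, type=2, norm=None)
recovered = c_ortho.copy()
recovered[0] *= math.sqrt(n_mels) * 2        # undo the sqrt(1/(4N)) factor
recovered[1:] *= math.sqrt(n_mels / 2) * 2   # undo the sqrt(1/(2N)) factor
assert np.allclose(recovered, c_none)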
Example #14
# import math

## Setting seed
import random

param.seed = param.seed or random.randint(1, 10000)
print("Random Seed: " + str(param.seed))
print("Random Seed: " + str(param.seed), file=log_output)
random.seed(param.seed)
torch.manual_seed(param.seed)
if param.cuda:
    torch.cuda.manual_seed_all(param.seed)

## Transforming audio files
trans = transf.Compose([
    transf.Scale(),  # This makes it into [-1,1]
    # transf.ToTensor(),
    transf.PadTrim(max_len=param.audio_size),  # I don't know if this is needed
    # This makes it into [-1,1] so tanh will work properly
    # transf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


def load_sound(path):
    tensor_to_load_into = None
    import torchaudio
    out, sample_rate = torchaudio.load(path, tensor_to_load_into)
    return out
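

# For reference, a hypothetical call site for the pieces above
# ("speech.wav" is a placeholder path, not from the original project):
# sig = load_sound("speech.wav")   # old torchaudio layout: (n_samples, n_channels)
# sig = trans(sig)                 # scaled into [-1, 1], then padded/trimmed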


## Importing dataset
Example #15
    seq_M = args.seq_M
    batch_size = args.batch_size
    depth = args.depth
    radixs = [2] * depth
    N = np.prod(radixs)
    channels = args.channels
    lr = args.lr
    steps = args.steps
    c = args.c
    generation_time = args.file_size
    filename = args.outfile

    maxlen = 50000
    print('==> Downloading YesNo Dataset..')
    transform = transforms.Compose(
        [transforms.Scale(),
         transforms.PadTrim(maxlen),
         transforms.MuLawEncoding(quantization_channels=channels)])
    data = torchaudio.datasets.YESNO('./data', download=True, transform=transform)
    data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, shuffle=True)

    print('==> Building model..')
    net = general_FFTNet(radixs, 128, channels).cuda()

    print(sum(p.numel() for p in net.parameters() if p.requires_grad), "of parameters.")

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    print("Start Training.")
    a = datetime.now().replace(microsecond=0)
Example #16
    seq_M = args.seq_M
    batch_size = args.batch_size
    depth = args.depth
    radixs = [2] * depth
    N = np.prod(radixs)
    channels = args.channels
    lr = args.lr
    steps = args.steps
    c = args.c
    generation_time = args.file_size
    filename = args.outfile
    features_size = args.feature_size

    print('==> Downloading YesNo Dataset..')
    transform = transforms.Compose([transforms.Scale()])
    data = torchaudio.datasets.YESNO('./data',
                                     download=True,
                                     transform=transform)
    data_loader = DataLoader(data, batch_size=1, num_workers=2)

    print('==> Extracting features..')
    train_wav = []
    train_features = []
    train_targets = []
    for batch_idx, (inputs, _) in enumerate(data_loader):
        inputs = inputs.view(-1).numpy()
        targets = np.roll(inputs, shift=-1)

        #h = mfcc(inputs, sr, winlen=winlen, winstep=winstep, numcep=features_size - 1, winfunc=np.hamming)
        x = inputs.astype(float)
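
The MuLawEncoding(quantization_channels=channels) step in the Compose above compands the waveform before quantization. A minimal NumPy sketch of the standard mu-law encoder it corresponds to (with mu = channels - 1), not torchaudio's exact code:

import numpy as np

def mu_law_encode(x: np.ndarray, channels: int = 256) -> np.ndarray:
    """Compand x in [-1, 1] with mu-law, then quantize to integers in [0, mu]."""
    mu = channels - 1
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)  # still in [-1, 1]
    return ((y + 1) / 2 * mu + 0.5).astype(np.int64)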
Example #17
import torch
import torchaudio.datasets as dset
from torchaudio import transforms

transform = transforms.Compose(
    [transforms.Scale(), transforms.PadTrim(100000)])

train_dataset = dset.YESNO("data", transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=10,
)

for i, (input, target) in enumerate(train_loader):
    import ipdb
    ipdb.set_trace(context=21)
    print("HI")
""" Vision MNIST test"""
"""
import torchvision.datasets as vdset
from torchvision import transforms as vtransforms

transform = vtransforms.Compose([
        vtransforms.ToTensor()
        ])

mnist = vdset.MNIST("data", transform=transform, download=True)

mnist_loader = torch.utils.data.DataLoader(
  mnist,
  batch_size=10,
)
"""
Example #18
File: main.py  Project: jozhang97/WaveApp
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Data loading code
  # Any other preprocessings? http://pytorch.org/audio/transforms.html
  sample_length = 10000
  scale = transforms.Scale()
  padtrim = transforms.PadTrim(sample_length)
  downmix = transforms.DownmixMono()
  transforms_audio = transforms.Compose([
    scale, padtrim, downmix
  ])

  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)
  train_dir = os.path.join(args.data_path, 'train')
  val_dir = os.path.join(args.data_path, 'val')

  #Choose dataset to use
  if args.dataset == 'arctic':
    # TODO No ImageFolder equivalent for audio. Need to create a Dataset manually
    train_dataset = Arctic(train_dir, transform=transforms_audio, download=True)
    val_dataset = Arctic(val_dir, transform=transforms_audio, download=True)
    num_classes = 4
  elif args.dataset == 'vctk':
    train_dataset = dset.VCTK(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.VCTK(val_dir, transform=transforms_audio, download=True)
    num_classes = 10
  elif args.dataset == 'yesno':
    train_dataset = dset.YESNO(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.YESNO(val_dir, transform=transforms_audio, download=True)
    num_classes = 2
  else:
    assert False, 'Dataset is incorrect'

  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    # pin_memory=True, # What is this?
    # sampler=None     # What is this?
  )
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)


  #Feed in respective model file to pass into model (alexnet.py)
  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  # net = models.__dict__[args.arch](num_classes)
  net = AlexNet(num_classes)
  #
  print_log("=> network :\n {}".format(net), log)

  # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  # Define stochastic gradient descent as optimizer (run backprop on random small batch)
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  #Sets use for GPU if available
  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  # Need same python version that the resume was in
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      if args.ngpu == 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      else:
        checkpoint = torch.load(args.resume)

      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(val_loader, net, criterion, 0, log, val_dataset)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  # Training occurs here
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    print("One epoch")
    # train for one epoch
    # Call to train (note that our previous net is passed into the model argument)
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, train_dataset)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(val_loader, net, criterion, epoch, log, val_dataset)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
Example #19
def evaluate():
  num_classes = 4

  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Any other preprocessings? http://pytorch.org/audio/transforms.html
  sample_length = 10000
  scale = transforms.Scale()
  padtrim = transforms.PadTrim(sample_length)
  transforms_audio = transforms.Compose([
    scale, padtrim
  ])


  # Data loading
  fs, data = wavfile.read(args.file_name)
  data = torch.from_numpy(data).float()
  data = data.unsqueeze(1)
  audio = transforms_audio(data)
  audio = Variable(audio)
  audio = audio.view(1, -1)
  audio = audio.unsqueeze(0)


  #Feed in respective model file to pass into model (alexnet.py)
  print_log("=> creating model '{}'".format(args.arch), log)

  # Init model, criterion, and optimizer
  # net = models.__dict__[args.arch](num_classes)
  net = AlexNet(num_classes)
  print_log("=> network :\n {}".format(net), log)


  #Sets use for GPU if available
  if args.use_cuda:
    net.cuda()

  # optionally resume from a checkpoint
  # Need same python version that the resume was in
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      if args.ngpu == 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      else:
        checkpoint = torch.load(args.resume)

      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  net.eval()
  if args.use_cuda:
    audio = audio.cuda()
  output = net(audio)
  print(output)
  # TODO postprocess output to a string representing the person speaking
  # output = val_dataset.postprocess_target(output)
  return
Example #20
n_coefficients = 12
low_mel_freq = 0
high_freq_mel = (2595 * np.log10(1 + (sr / 2) / 700))
mel_pts = np.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)
hz_pts = np.floor(700 * (10**(mel_pts / 2595) - 1))
bins = np.floor((n_fft + 1) * hz_pts / sr)

# data transformations
td = {
    "RfftPow": RfftPow(n_fft),
    "FilterBanks": FilterBanks(n_filterbanks, bins),
    "MFCC": MFCC(n_filterbanks, n_coefficients),
}

transforms = tat.Compose([
    tat.Scale(),
    tat.PadTrim(58000, fill_value=1e-8),
    Preemphasis(),
    Sig2Features(ws, hs, td),
])
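
Preemphasis here is a project-local transform; the conventional filter it presumably implements boosts high frequencies before feature extraction. A minimal sketch, assuming the usual coefficient of 0.97:

import numpy as np

def preemphasis(x: np.ndarray, alpha: float = 0.97) -> np.ndarray:
    """First-order high-pass filter: y[t] = x[t] - alpha * x[t-1]."""
    return np.concatenate([x[:1], x[1:] - alpha * x[:-1]])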

# set network parameters
use_cuda = torch.cuda.is_available()
batch_size = args.batch_size
input_features = 26
hidden_size = 100
output_size = 3
#output_length = (8 + 7 + 2) # with "blanks"
output_length = 8  # without blanks
n_layers = 1
attn_modus = "dot"