Esempio n. 1
0
    def test_downmix_mono(self):
        """Check that DownmixMono reduces a two-channel signal to one channel.

        A stereo tensor is built from ``self.sig`` (left channel) and a copy of
        it rotated by 10% of its length along dim 0 (right channel), joined
        along dim 1. ``DownmixMono(channels_first=False)`` must return a tensor
        whose dim 1 has size 1, and the transform must have a non-empty repr.
        """
        # NOTE(review): assumes self.sig has size 1 along dim 1, so that the
        # concatenation below yields exactly two channels -- confirm in setUp.
        audio_L = self.sig.clone()
        audio_R = self.sig.clone()

        # Rotate the right channel so the two channels are not identical.
        R_idx = int(audio_R.size(0) * 0.1)
        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))

        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)

        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(audio_Stereo.size(1), 2)

        downmix = transforms.DownmixMono(channels_first=False)
        result = downmix(audio_Stereo)

        self.assertEqual(result.size(1), 1)

        # The transform must expose a non-empty string representation.
        self.assertTrue(repr(downmix))
Esempio n. 2
0
def main():
  """Train, or only evaluate, an AlexNet audio classifier driven by the module-level ``args``.

  Steps, in order:
    1. Create ``args.save_path`` if needed and open ``log_seed_<manualSeed>.txt`` in it.
    2. Log the run configuration and the python / torch / cudnn versions.
    3. Build the audio transform pipeline, then the train/val datasets and
       loaders for ``args.dataset`` ('arctic', 'vctk' or 'yesno').
    4. Build the network, the cross-entropy loss and the SGD optimizer,
       moving network and loss to the GPU when ``args.use_cuda`` is set.
    5. Optionally restore network, optimizer, recorder and start epoch from
       the checkpoint at ``args.resume``.
    6. If ``args.evaluate`` is set, run one validation pass and return.
    7. Otherwise iterate epochs from ``args.start_epoch`` to ``args.epochs``,
       saving a checkpoint and the recorder's 'curve.png' after every epoch.

  The log file is closed on every exit path: normal completion, the
  evaluate-only early return, and any exception.

  Raises:
    AssertionError: if ``args.dataset`` is not one of the supported names.
  """
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log_path = os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed))

  # The context manager closes the log even on the early return taken in
  # evaluate-only mode or when an exception propagates out of training.
  with open(log_path, 'w') as log:
    print_log('save path : {}'.format(args.save_path), log)
    state = dict(args._get_kwargs())
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    # Data loading code
    # Any other preprocessings? http://pytorch.org/audio/transforms.html
    sample_length = 10000
    transforms_audio = transforms.Compose([
      transforms.Scale(),
      transforms.PadTrim(sample_length),
      transforms.DownmixMono(),
    ])

    if not os.path.isdir(args.data_path):
      os.makedirs(args.data_path)
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')

    # Choose dataset to use
    if args.dataset == 'arctic':
      # TODO No ImageFolder equivalent for audio. Need to create a Dataset manually
      train_dataset = Arctic(train_dir, transform=transforms_audio, download=True)
      val_dataset = Arctic(val_dir, transform=transforms_audio, download=True)
      num_classes = 4
    elif args.dataset == 'vctk':
      train_dataset = dset.VCTK(train_dir, transform=transforms_audio, download=True)
      val_dataset = dset.VCTK(val_dir, transform=transforms_audio, download=True)
      num_classes = 10
    elif args.dataset == 'yesno':
      train_dataset = dset.YESNO(train_dir, transform=transforms_audio, download=True)
      val_dataset = dset.YESNO(val_dir, transform=transforms_audio, download=True)
      num_classes = 2
    else:
      # Raised explicitly (rather than `assert False`) so the check is not
      # stripped when Python runs with -O; the exception type is unchanged.
      raise AssertionError('Dataset is incorrect')

    train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.workers,
      # pin_memory=True, # What is this?
      # sampler=None     # What is this?
    )
    val_loader = torch.utils.data.DataLoader(
      val_dataset,
      batch_size=args.batch_size,
      shuffle=False,
      num_workers=args.workers,
      pin_memory=True,
    )

    # Feed in respective model file to pass into model (alexnet.py)
    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer.
    # NOTE: the network is always AlexNet here; args.arch is only logged and
    # stored in checkpoints.
    net = AlexNet(num_classes)
    print_log("=> network :\n {}".format(net), log)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # Define stochastic gradient descent as optimizer (run backprop on random small batch)
    optimizer = torch.optim.SGD(
      net.parameters(),
      state['learning_rate'],
      momentum=state['momentum'],
      weight_decay=state['decay'],
      nesterov=True,
    )

    # Move network and loss to the GPU when requested
    if args.use_cuda:
      net.cuda()
      criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    # Need same python version that the resume was in
    if args.resume:
      if os.path.isfile(args.resume):
        print_log("=> loading checkpoint '{}'".format(args.resume), log)
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        if args.ngpu == 0:
          # No GPUs: remap every storage to the CPU while loading.
          checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        else:
          checkpoint = torch.load(args.resume)

        recorder = checkpoint['recorder']
        args.start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
      else:
        print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
      print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    # Evaluate-only mode: one validation pass, no training.
    if args.evaluate:
      validate(val_loader, net, criterion, 0, log, val_dataset)
      return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    # Training occurs here
    for epoch in range(args.start_epoch, args.epochs):
      current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

      # Estimate remaining time from the average epoch duration so far.
      need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
      need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

      epoch_header = '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
        time_string(), epoch, args.epochs, need_time, current_learning_rate)
      best_accuracy = recorder.max_accuracy(False)
      best_summary = ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(best_accuracy, 100 - best_accuracy)
      print_log(epoch_header + best_summary, log)

      print("One epoch")
      # train for one epoch
      # Call to train (note that our previous net is passed into the model argument)
      train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, train_dataset)

      # evaluate on validation set
      val_acc, val_los = validate(val_loader, net, criterion, epoch, log, val_dataset)
      is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

      # 'epoch' stores the next epoch to run, so a resume continues after this one.
      save_checkpoint({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': net.state_dict(),
        'recorder': recorder,
        'optimizer': optimizer.state_dict(),
      }, is_best, args.save_path, 'checkpoint.pth.tar')

      # measure elapsed time
      epoch_time.update(time.time() - start_time)
      start_time = time.time()
      recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))