def main():
    # Data
    testset = ChainDataset(args.test, train=False, on_the_fly=args.on_the_fly)
    testloader = AudioDataLoader(testset, batch_size=args.bsz)

    # Model
    checkpoint_path = os.path.join(args.exp, args.model)
    with open(checkpoint_path, 'rb') as f:
        state = torch.load(f)
    model_args = state['args']
    print("==> creating model '{}'".format(model_args.arch))
    model = get_model(model_args.feat_dim, model_args.num_targets,
                      model_args.layers, model_args.hidden_dims,
                      model_args.arch,
                      kernel_sizes=model_args.kernel_sizes,
                      dilations=model_args.dilations,
                      strides=model_args.strides,
                      bidirectional=model_args.bidirectional)
    print(model)

    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()

    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    model.load_state_dict(state['state_dict'])

    output_file = os.path.join(args.exp, args.results)
    test(testloader, model, output_file, use_cuda)
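# The eval script above delegates decoding to a test() helper that is not
# shown in this file. The sketch below is only an illustration of what such
# a helper might do, assuming batches of (feats, lengths, keys) and
# per-utterance outputs written with torch.save; none of these details are
# confirmed by the code above.
import torch


def test_sketch(testloader, model, output_file, use_cuda):
    model.eval()
    outputs = {}
    with torch.no_grad():
        for feats, lengths, keys in testloader:  # assumed batch layout
            if use_cuda:
                feats = feats.cuda()
            log_probs = model(feats)  # (batch, time, num_targets)
            for i, key in enumerate(keys):
                # drop the padded frames of each utterance
                outputs[key] = log_probs[i, :lengths[i]].cpu()
    torch.save(outputs, output_file)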
def main():
    global best_loss
    writer = SummaryWriter(args.exp)
    print('Saving model and logs to {}'.format(args.exp))
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Data
    trainset = ChainDataset(args.train)
    trainsampler = BucketSampler(trainset, args.train_bsz)
    trainloader = AudioDataLoader(trainset, batch_sampler=trainsampler)
    validset = ChainDataset(args.valid)
    validloader = AudioDataLoader(validset, batch_size=args.valid_bsz)

    # Model
    print("==> creating model '{}'".format(args.arch))
    model = get_model(args.feat_dim, args.num_targets, args.layers,
                      args.hidden_dims, args.arch,
                      kernel_sizes=args.kernel_sizes,
                      dilations=args.dilations,
                      strides=args.strides,
                      bidirectional=args.bidirectional,
                      dropout=args.dropout)
    print(model)

    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # loss
    den_fst = simplefst.StdVectorFst.read(args.den_fst)
    den_graph = ChainGraph(den_fst, leaky_mode='transition')
    criterion = ChainLoss(den_graph)

    # optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               betas=(args.beta1, args.beta2))

    # Resume
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_loss = checkpoint['best_loss']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # learning rate scheduler
    if args.scheduler == 'step':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, milestones=args.milestones, gamma=args.gamma,
            last_epoch=start_epoch - 1)
    elif args.scheduler == 'exp':
        gamma = args.gamma ** (1.0 / args.epochs)  # final_lr = init_lr * gamma
        scheduler = lr_scheduler.ExponentialLR(
            optimizer, gamma=gamma, last_epoch=start_epoch - 1)
    elif args.scheduler == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.gamma, patience=1)

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        if epoch >= args.curriculum:
            trainsampler.shuffle(epoch)
        train_loss = train(trainloader, model, criterion, optimizer,
                           writer, epoch, use_cuda)
        valid_loss = test(validloader, model, criterion, writer, epoch,
                          use_cuda)

        # save model
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'loss': valid_loss,
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, exp=args.exp)
        # Only ReduceLROnPlateau takes the validation metric; the epoch-based
        # schedulers are stepped without arguments.
        if args.scheduler == 'plateau':
            scheduler.step(valid_loss)
        else:
            scheduler.step()

    print('Best loss:')
    print(best_loss)
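# Both training mains call a save_checkpoint() helper that is not shown. A
# minimal sketch under the common PyTorch convention of writing the latest
# state and copying it aside when it is the best so far; the file names
# 'checkpoint.pth.tar' and 'model_best.pth.tar' are assumptions.
import os
import shutil
import torch


def save_checkpoint(state, is_best, exp='exp', filename='checkpoint.pth.tar'):
    filepath = os.path.join(exp, filename)
    torch.save(state, filepath)  # always keep the most recent checkpoint
    if is_best:
        # keep a separate copy of the best checkpoint seen so far
        shutil.copyfile(filepath, os.path.join(exp, 'model_best.pth.tar'))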
def main():
    global best_loss
    writer = SummaryWriter(args.exp)
    print('Saving model and logs to {}'.format(args.exp))
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Data
    trainset = ChainDataset(args.train_feat_dir, args.train_tree_dir)
    trainloader = AudioDataLoader(trainset, batch_size=args.train_batch)
    testset = ChainDataset(args.val_feat_dir, args.val_tree_dir)
    testloader = AudioDataLoader(testset, batch_size=args.test_batch)

    # Model
    print("==> creating model '{}'".format(args.arch))
    model = get_model(args.feat_dim, args.out_dim, args.layers,
                      args.hidden_dims, args.arch,
                      kernel_sizes=args.kernel_sizes,
                      dilations=args.dilations,
                      bidirectional=args.bidirectional)

    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # loss
    den_fst = simplefst.StdVectorFst.read(args.den_fst)
    den_graph = ChainGraph(den_fst, initial='recursive')
    criterion = ChainLoss(den_graph)

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resume
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_loss = checkpoint['best_loss']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # learning rate scheduler
    if args.scheduler == 'step':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, milestones=args.milestones, gamma=args.gamma,
            last_epoch=start_epoch - 1)
    elif args.scheduler == 'exp':
        gamma = args.gamma ** (1.0 / args.epochs)  # final_lr = init_lr * gamma
        scheduler = lr_scheduler.ExponentialLR(
            optimizer, gamma=gamma, last_epoch=start_epoch - 1)

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        train_loss = train(trainloader, model, criterion, optimizer,
                           writer, epoch, use_cuda)
        test_loss = test(testloader, model, criterion, writer, epoch,
                         use_cuda)

        # save model (lower loss is better)
        is_best = test_loss < best_loss
        best_loss = min(test_loss, best_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'loss': test_loss,
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, exp=args.exp)
        scheduler.step()  # step once per epoch, after the optimizer updates

    print('Best loss:')
    print(best_loss)
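# train() and test() are also not shown. Below is a rough sketch of a
# per-epoch training loop consistent with the calls above; the batch layout
# and the criterion(output, lengths, num_graphs) call are assumptions about
# the chain-loss API, not its documented interface.
def train_sketch(trainloader, model, criterion, optimizer, writer, epoch,
                 use_cuda):
    model.train()
    total_loss = 0.0
    for feats, lengths, num_graphs in trainloader:  # assumed batch layout
        if use_cuda:
            feats = feats.cuda()
        output = model(feats)  # (batch, time, num_targets) log-probs
        loss = criterion(output, lengths, num_graphs)  # assumed signature
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(trainloader)
    writer.add_scalar('train/loss', avg_loss, epoch)  # TensorBoard logging
    return avg_loss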