def _seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from the worker's own torch seed so that worker
    processes are reproducible without all drawing the same random stream.
    ``worker_id`` is supplied by the DataLoader and is not needed here.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def main():
    """Decompose a pretrained model and fine-tune the compressed network.

    Parses the command line, builds the network (original or segmented
    variant), loads the pretrained weights, applies the requested
    decomposition and then fine-tunes with Adam + ReduceLROnPlateau.  A
    snapshot is written every epoch and training stops early after
    ``max_patience`` consecutive epochs without improvement.

    Relies on module-level configuration (``atten_activation``,
    ``atten_channel``, ``M``, ``batch_size``, ``test_batch_size``,
    ``select_best``, ``rnn``) defined elsewhere in the file.

    Raises:
        ValueError: if ``select_best`` is neither ``'eer'`` nor ``'val'``.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-dir', required=True,
                        help='train feature dir')
    parser.add_argument('--train-utt2label', required=True,
                        help='train utt2label')
    parser.add_argument('--validation-dir', required=True,
                        help='dev feature dir')
    parser.add_argument('--validation-utt2label', required=True,
                        help='dev utt2label')
    parser.add_argument('--model-path', help='path to the pretrained model')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    parser.add_argument('--rank', type=int, required=True,
                        help='rank value for decomposition: 2 or 4')
    parser.add_argument('--decomp-type', default='tucker',
                        help='decomposition type: tucker cp')
    args = parser.parse_args()

    # The segmented model needs a concrete window width; fail with a usage
    # message instead of building a model with size1=(257, None).
    if args.seg is not None and args.seg_win is None:
        parser.error('--seg-win is required when --seg is given')
    # Fail fast: an unknown criterion would otherwise only surface as a
    # NameError after a full training epoch.
    if select_best not in ('eer', 'val'):
        raise ValueError(
            "select_best must be 'eer' or 'val', got {!r}".format(select_best))

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = ("fine_decomp-AFN4-" + str(M) + "-orig-rank_"
                    + str(args.rank) + '-' + args.decomp_type)
    else:
        model = AttenResNet4DeformAll(atten_activation, atten_channel,
                                      size1=(257, args.seg_win))
        run_name = ("fine_decomp-AFN4De-" + str(args.seg_win) + "-"
                    + args.seg + "-rank_" + str(args.rank) + '-'
                    + args.decomp_type)
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda" if use_cuda else "cpu")
    model = load_model(model, args.model_path)

    # perform decomposition
    model = apply_net_decomp(model=model, logger=logger, rank=args.rank,
                             decomp_type=args.decomp_type)
    model.to(device)

    ##############################################################
    # Loading the dataset and fine tune the compressed model
    # worker_init_fn must be a callable; the previous code passed the
    # result of np.random.seed(...), i.e. None.
    params = {
        'num_workers': 0,
        'pin_memory': False,
        'worker_init_fn': _seed_worker,
    } if use_cuda else {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_dir, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_dir,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set, batch_size=batch_size,
                                   shuffle=True, **params)
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=test_batch_size,
                                        shuffle=False, **params)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                     factor=0.5, patience=1)

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('#### Decomposed model summary below####\n {}\n'.format(
        str(model)))
    logger.info('===> Model total parameter: {}\n'.format(model_params))

    ###########################################################
    # Start training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and max patience
    total_train_time = []
    # NOTE: the original rebound ``best_model = model`` every epoch, which
    # is an alias of the live model, not a copy of the best weights.  Bind
    # it once so the final report also works when no epoch runs.
    best_model = model

    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)

        # Save
        if select_best == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # select_best == 'val' (validated above)
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })

        # Early stopping
        if is_best:
            early_stopping = 0
        else:
            early_stopping += 1

        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, elapsed time: {}".format(
            epoch, args.epochs, end_epoch_timer - epoch_timer))
        total_train_time.append(end_epoch_timer - epoch_timer)
        if early_stopping == max_patience:
            break

    if total_train_time:  # avoid averaging an empty list when epochs == 0
        logger.info("#### Avg. training+validation time per epoch: {}".format(
            np.average(total_train_time)))

    ###########################################################
    logger.info("#### fine-tuned decomp model size (MB): {}".format(
        get_size_of_model(best_model)))
    model_params = sum(p.numel() for p in best_model.parameters()
                       if p.requires_grad)
    logger.info(
        '#### non-zero params after fine-tuning: {}'.format(model_params))
    logger.info(
        "################## Done fine-tuning decomp model ######################"
    )
def main():
    """Run ``retrieve_weight`` on the eval set for each listed checkpoint.

    Parses the command line, builds the eval DataLoader and, for every
    checkpoint path in the module-level ``models`` sequence, loads its
    ``state_dict`` into the module-level ``model`` and calls
    ``retrieve_weight``.

    Relies on module-level names (``model``, ``models``, ``run_name``,
    ``rnn``) defined elsewhere in the file.
    """
    import sys  # local import: only needed for the print threshold below

    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--eval-scp', help='kaldi eval scp file')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    # NOTE(review): --model-path is parsed but never read here; the
    # checkpoints come from the module-level ``models``.
    parser.add_argument('--model-path', help='trained model')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--test-batch-size', type=int, default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--plot-dir', help='directory to save plots')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print('use_cuda is', use_cuda)

    # Global timer
    global_timer = timer()

    # Setup logs
    logger = setup_logs(args.logging_dir, run_name)

    # Setting random seeds for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    ##############################################################
    # Loading the dataset
    params = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}

    logger.info('===> loading eval dataset')
    eval_set = SpoofDataset(args.eval_scp, args.eval_utt2label)
    eval_loader = data.DataLoader(eval_set, batch_size=args.test_batch_size,
                                  shuffle=False, **params)

    ################### for multiple models #####################
    # Print arrays in full.  NumPy rejects threshold=np.nan with a
    # ValueError; sys.maxsize is the documented way to disable truncation.
    np.set_printoptions(threshold=sys.maxsize)

    for model_i in models:
        logger.info('===> loading {} for prediction'.format(model_i))
        # map_location lets a checkpoint saved on GPU load on a CPU-only run.
        checkpoint = torch.load(model_i, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
        print('model params is', model_params)

        retrieve_weight(args, model, device, eval_loader, args.eval_scp,
                        args.eval_utt2label, args.plot_dir, rnn)

    logger.info("===> Final predictions done. Here is a snippet")

    ###########################################################
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))
def main():
    """Quantize a trained model and run prediction on the eval set (CPU).

    Parses the command line, builds the network, loads the checkpoint given
    by ``--model-path``, applies ``apply_net_quant`` (passing the training
    loader as ``calibration_loader``) and logs the total prediction time.

    Relies on module-level configuration (``atten_activation``,
    ``atten_channel``, ``M``, ``BATCH_SIZE``, ``TEST_BATCH_SIZE``, ``rnn``)
    defined elsewhere in the file.

    Raises:
        ValueError: if ``--seg`` is given and ``--seg-win`` is not one of
            64, 128, 256 or 512.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--train-dir', help='train feature dir')
    parser.add_argument('--train-utt2label', help='train utt2label')
    parser.add_argument('--eval-dir', help='eval feature dir')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--model-path', help='path to the pretrained model')
    # NOTE(review): --test-batch-size is parsed but the loaders below use
    # the module-level TEST_BATCH_SIZE; confirm which one is intended.
    parser.add_argument('--test-batch-size', type=int, default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    parser.add_argument('--quant-method', required=True,
                        help='quantization method: dynamic static')
    args = parser.parse_args()

    # These two are concatenated into log messages and used to load the
    # checkpoint below; report a usage error instead of a TypeError on None.
    missing = [flag for flag, value in (
        ('--eval-utt2label', args.eval_utt2label),
        ('--model-path', args.model_path),
    ) if value is None]
    if missing:
        parser.error('missing required argument(s): ' + ', '.join(missing))

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel,
                             size1=(257, M), static_quant=True)
        run_name = "quant_pred-AFN4-1091-orig" + time.strftime("-%Y_%m_%d")
    else:
        if args.seg_win not in (64, 128, 256, 512):
            raise ValueError(
                "Invalid segment window! Must be 64, 128, 256, or 512")
        model = AttenResNet4DeformAll(atten_activation, atten_channel,
                                      size1=(257, args.seg_win))
        run_name = ("quant_pred-AFN4De-" + str(args.seg_win) + "-"
                    + args.seg + time.strftime("-%Y_%m_%d"))
    logger = setup_logs(args.logging_dir, run_name)

    logger.info("use_cuda is False. Only runnable on CPU!")

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device('cpu')
    model.to(device)

    ##############################################################
    # Loading the dataset & the best model
    params = {}

    logger.info('===> loading eval dataset: ' + args.eval_utt2label)
    train_set = SpoofDataset(args.train_dir, args.train_utt2label)
    train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                   shuffle=True, **params)
    eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
    eval_loader = data.DataLoader(eval_set, batch_size=TEST_BATCH_SIZE,
                                  shuffle=False, **params)

    logger.info('===> loading best model for prediction: ' + args.model_path)
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    ##############################################################
    # apply network quantization & prediction
    model = apply_net_quant(
        model=model,
        logger=logger,
        quant_method=args.quant_method,
        calibration_loader=train_loader,
    )

    t_start_eval = timer()
    prediction(args, model, device, eval_loader, args.eval_utt2label, rnn)
    t_end_eval = timer()
    logger.info("#### Total prediction time: {}".format(
        t_end_eval - t_start_eval))

    ###########################################################
    logger.info("################## Success #########################\n\n")
def _seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from the worker's own torch seed so that worker
    processes are reproducible without all drawing the same random stream.
    ``worker_id`` is supplied by the DataLoader and is not needed here.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def main():
    """Rebuild a decomposed model, load fine-tuned weights and predict.

    Parses the command line, builds the network, applies
    ``apply_net_decomp`` so the architecture matches the fine-tuned
    checkpoint, loads that checkpoint and logs the total prediction time on
    the eval set.

    Relies on module-level configuration (``atten_activation``,
    ``atten_channel``, ``M``, ``TEST_BATCH_SIZE``, ``rnn``) defined
    elsewhere in the file.

    Raises:
        ValueError: if ``--seg`` is given and ``--seg-win`` is not one of
            64, 128, 256 or 512.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--eval-dir', help='eval feature dir')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--model-path', help='path to the pretrained model')
    # NOTE(review): --test-batch-size is parsed but the loader below uses
    # the module-level TEST_BATCH_SIZE; confirm which one is intended.
    parser.add_argument('--test-batch-size', type=int, default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    parser.add_argument('--decomp-rank', type=int,
                        help='rank value for decomposition')
    parser.add_argument('--decomp-type', default='tucker',
                        help='decomposition type')
    args = parser.parse_args()

    # These two are concatenated into log messages and used to load the
    # checkpoint below; report a usage error instead of a TypeError on None.
    missing = [flag for flag, value in (
        ('--eval-utt2label', args.eval_utt2label),
        ('--model-path', args.model_path),
    ) if value is None]
    if missing:
        parser.error('missing required argument(s): ' + ', '.join(missing))

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = "decomp_pred-AFN4-1091-orig" + time.strftime("-%Y_%m_%d")
    else:
        if args.seg_win not in (64, 128, 256, 512):
            raise ValueError(
                "Invalid segment window! Must be 64, 128, 256, or 512")
        model = AttenResNet4DeformAll(atten_activation, atten_channel,
                                      size1=(257, args.seg_win))
        run_name = ("decomp_pred-AFN4De-" + str(args.seg_win) + "-"
                    + args.seg + time.strftime("-%Y_%m_%d"))
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda" if use_cuda else "cpu")

    ##############################################################
    # Loading the dataset
    # worker_init_fn must be a callable; the previous code passed the
    # result of np.random.seed(...), i.e. None.
    params = {
        'num_workers': 0,
        'pin_memory': False,
        'worker_init_fn': _seed_worker,
    } if use_cuda else {}

    logger.info('===> loading eval dataset: ' + args.eval_utt2label)
    eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
    eval_loader = data.DataLoader(eval_set, batch_size=TEST_BATCH_SIZE,
                                  shuffle=False, **params)

    ##############################################################
    # apply network decomposition and load fine-tuned weights
    # The decomposition runs on the CPU copy of the model; the model is
    # moved to the target device only once the weights are loaded.
    model.to('cpu')
    model = apply_net_decomp(
        model=model,
        logger=logger,
        rank=args.decomp_rank,
        decomp_type=args.decomp_type,
    )

    logger.info('===> loading fine-tuned model for prediction: '
                + args.model_path)
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    t_start_eval = timer()
    prediction(args, model, device, eval_loader, args.eval_utt2label, rnn)
    t_end_eval = timer()
    logger.info("#### Total prediction time: {}".format(
        t_end_eval - t_start_eval))

    ###########################################################
    logger.info("################## Success #########################\n\n")
def main():
    """Extract log-spectrogram features for every utterance in a protocol.

    Reads ``--label-file`` line by line (first field: audio file name,
    second field: ``genuine`` or ``spoof``), writes a
    ``<label-file>.utt2label`` mapping next to it (genuine -> 1,
    spoof -> 0) and saves one ``.pt`` feature tensor per utterance into
    ``--feat-dir``.

    Relies on module-level names (``DATA_NAME``, ``DEFAULT_SEG``,
    ``TAIL_SEG``) and helpers (``get_logspec``, ``expand_logspec``,
    ``segment_logspec``, ``setup_logs``) defined elsewhere in the file.

    Raises:
        ValueError: if ``--segment`` is set without ``--seg-method`` and
            ``--seg-win``, if a protocol line has fewer than two fields, or
            if a label is neither ``genuine`` nor ``spoof``.
    """
    parser = argparse.ArgumentParser(
        description='Feature (log-spec) Extraction')
    parser.add_argument('--data-dir', required=True,
                        help='data directory contains wave files')
    parser.add_argument('--label-file', required=True,
                        help='protocol file that contains utt2label mapping')
    parser.add_argument('--feat-dir', required=True,
                        help='feature saving directory')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA for feature extraction')
    parser.add_argument('--logging-dir', required=True,
                        help='log save directory')
    parser.add_argument('--segment', action='store_true', default=False,
                        help='whether to segment the logsepc by energy')
    parser.add_argument(
        '--seg-win', type=int,
        help='the window size to be used for segment: 64, 128..')
    parser.add_argument(
        '--seg-method',
        help='the method to be used for segment: h, l, hl, lh')
    args = parser.parse_args()

    os.makedirs(args.logging_dir, exist_ok=True)
    os.makedirs(args.feat_dir, exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.segment and (args.seg_method is None or args.seg_win is None):
        raise ValueError("segment method or win_size is missing")

    # Setup logs
    basename = DATA_NAME + os.path.basename(args.data_dir)
    run_name = "logspec-" + basename + time.strftime("-%Y-%m-%d")
    # NOTE(review): this checks the current working directory, while the
    # logger is set up with args.logging_dir -- confirm where setup_logs
    # actually writes <run_name>.log.
    if os.path.exists(run_name + ".log"):
        os.remove(run_name + ".log")
    logger = setup_logs(args.logging_dir, run_name)

    logger.info("===> Start logspec extraction for dataset: " + args.data_dir)

    # global timer starts
    global_start = timer()

    utt2label_path = os.path.join(
        os.path.dirname(args.label_file),
        os.path.basename(args.label_file) + '.utt2label')
    label_ids = {'genuine': 1, 'spoof': 0}

    # Context managers close both files even if a line fails to parse or
    # feature extraction raises.
    with open(args.label_file, 'r') as f_label, \
            open(utt2label_path, 'w') as f_utt2label:
        for line in f_label:
            item = line.split()
            if not item:
                continue  # tolerate blank lines in the protocol file
            if len(item) < 2:
                raise ValueError("Malformed label line: " + line.strip())
            if item[1] not in label_ids:
                raise ValueError("Invalid label: " + item[1])
            label = label_ids[item[1]]

            # Strip the real extension instead of assuming 4 characters.
            utt_id = os.path.splitext(item[0])[0]
            f_utt2label.write(utt_id + ' ' + str(label) + '\n')

            audio_path = os.path.join(args.data_dir, item[0])

            t_start = timer()
            logspec = get_logspec(audio_path, device)
            if args.segment:
                if args.seg_method == DEFAULT_SEG:
                    feat = expand_logspec(logspec, M=args.seg_win)
                elif args.seg_method == TAIL_SEG:
                    # Flip, expand, flip back.
                    feat = torch.flipud(
                        expand_logspec(torch.flipud(logspec),
                                       M=args.seg_win))
                else:
                    feat = segment_logspec(logspec, args.seg_win,
                                           args.seg_method)
            else:
                feat = expand_logspec(logspec)
            t_end = timer()
            logger.info(item[0] + "\tfeature extraction time: %s"
                        % (t_end - t_start))

            f_feat_path = os.path.join(args.feat_dir, utt_id + '.pt')
            torch.save(feat, f_feat_path)

    global_end = timer()
    logger.info("#### Done logspec extraction for dataset: " + args.data_dir
                + "####")
    logger.info("Total elapsed time: %s" % (global_end - global_start))
def _seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from the worker's own torch seed so that worker
    processes are reproducible without all drawing the same random stream.
    Defined at module level so it stays picklable for worker start-up.
    ``worker_id`` is supplied by the DataLoader and is not needed here.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def main():
    """Train the module-level model and optionally predict on an eval set.

    Parses the command line, builds train/dev (and, when both eval
    arguments are given, eval) DataLoaders, trains with Adam +
    ReduceLROnPlateau, writes a snapshot each epoch and stops early after
    ``max_patience`` consecutive epochs without improvement.  If an eval
    set was provided, the best snapshot is reloaded and ``prediction`` is
    run on it.

    Relies on module-level names (``model``, ``run_name``, ``batch_size``,
    ``test_batch_size``, ``select_best``, ``rnn``) defined elsewhere in the
    file.

    Raises:
        ValueError: if ``select_best`` is neither ``'eer'`` nor ``'val'``.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-scp', required=True,
                        help='kaldi train scp file')
    parser.add_argument('--train-utt2label', required=True,
                        help='train utt2label')
    parser.add_argument('--validation-scp', required=True,
                        help='kaldi dev scp file')
    parser.add_argument('--validation-utt2label', required=True,
                        help='dev utt2label')
    parser.add_argument('--eval-scp', help='kaldi eval scp file')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging '
                             'training status')
    parser.add_argument('--hidden-dim', type=int, default=100,
                        help='number of neurones in the hidden dimension')
    parser.add_argument('--plot-wd', help='training plot directory')
    args = parser.parse_args()

    # Fail fast: an unknown criterion would otherwise only surface as a
    # NameError after a full training epoch.
    if select_best not in ('eer', 'val'):
        raise ValueError(
            "select_best must be 'eer' or 'val', got {!r}".format(select_best))

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print('use_cuda is', use_cuda)

    # Global timer
    global_timer = timer()

    # Setup logs
    logger = setup_logs(args.logging_dir, run_name)

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda:0" if use_cuda else "cpu")
    model.to(device)

    ##############################################################
    # Loading the dataset
    # worker_init_fn must be a callable; the previous code passed the
    # result of np.random.seed(...), i.e. None, so the 16 workers were
    # never explicitly seeded.
    params = {
        'num_workers': 16,
        'pin_memory': True,
        'worker_init_fn': _seed_worker,
    } if use_cuda else {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_scp, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_scp,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set, batch_size=batch_size,
                                   shuffle=True, **params)
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=test_batch_size,
                                        shuffle=False, **params)

    # The eval arguments are optional, so only build the eval loader when
    # both are present.  It is still built before training so that a bad
    # eval path fails early rather than after hours of training.
    eval_loader = None
    if args.eval_scp and args.eval_utt2label:
        logger.info('===> loading eval dataset')
        eval_set = SpoofDataset(args.eval_scp, args.eval_utt2label)
        eval_loader = data.DataLoader(eval_set, batch_size=test_batch_size,
                                      shuffle=False, **params)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                     factor=0.5, patience=1)

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('### Model summary below###\n {}\n'.format(str(model)))
    logger.info('===> Model total parameter: {}\n'.format(model_params))

    ###########################################################
    # Start training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and max patience

    print(run_name)
    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_scp,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)

        # Save
        if select_best == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # select_best == 'val' (validated above)
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })

        # Early stopping
        if is_best:
            early_stopping = 0
        else:
            early_stopping += 1

        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, "
                    "elapsed time: {}".format(epoch, args.epochs,
                                              end_epoch_timer - epoch_timer))
        if early_stopping == max_patience:
            break

    ###########################################################
    # Prediction
    if eval_loader is not None:
        logger.info('===> loading best model for prediction')
        checkpoint = torch.load(
            os.path.join(args.logging_dir, run_name + '-model_best.pth'),
            map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

        prediction(args, model, device, eval_loader, args.eval_scp,
                   args.eval_utt2label, rnn)

    ###########################################################
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))
def _seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from the worker's own torch seed so that worker
    processes are reproducible without all drawing the same random stream.
    ``worker_id`` is supplied by the DataLoader and is not needed here.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def main():
    """Train an AFN model and optionally predict on an eval set.

    Parses the command line, builds the network (original or segmented
    variant), trains with Adam + ReduceLROnPlateau, writes a snapshot each
    epoch and stops early after ``max_patience`` consecutive epochs without
    improvement.  When both ``--eval-dir`` and ``--eval-utt2label`` are
    given, the best snapshot is reloaded and ``prediction`` is run on the
    eval set.

    Relies on module-level configuration (``atten_activation``,
    ``atten_channel``, ``M``, ``BATCH_SIZE``, ``TEST_BATCH_SIZE``,
    ``SELECT_BEST``, ``rnn``) defined elsewhere in the file.

    Raises:
        ValueError: if ``SELECT_BEST`` is neither ``'eer'`` nor ``'val'``.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--train-dir', required=True,
                        help='train feature dir')
    parser.add_argument('--train-utt2label', required=True,
                        help='train utt2label')
    parser.add_argument('--validation-dir', required=True,
                        help='dev feature dir')
    parser.add_argument('--validation-utt2label', required=True,
                        help='dev utt2label')
    parser.add_argument('--eval-dir', help='eval feature dir')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    args = parser.parse_args()

    # The segmented model needs a concrete window width; fail with a usage
    # message instead of building a model with size1=(257, None).
    if args.seg is not None and args.seg_win is None:
        parser.error('--seg-win is required when --seg is given')
    # Fail fast: an unknown criterion would otherwise only surface as a
    # NameError after a full training epoch.
    if SELECT_BEST not in ('eer', 'val'):
        raise ValueError(
            "SELECT_BEST must be 'eer' or 'val', got {!r}".format(SELECT_BEST))

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = ("AFN4" + time.strftime("-%Y_%m_%d-%H_%M_%S-") + str(M)
                    + "-orig")
    else:
        model = AttenResNet4DeformAll(atten_activation, atten_channel,
                                      size1=(257, args.seg_win))
        run_name = ("AFN4De" + time.strftime("-%Y_%m_%d-%H_%M_%S-")
                    + str(args.seg_win) + "-" + args.seg)
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Global timer
    global_timer = timer()

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    ##############################################################
    # Loading the dataset
    # worker_init_fn must be a callable; the previous code passed the
    # result of np.random.seed(...), i.e. None.
    params = {
        'num_workers': 0,
        'pin_memory': False,
        'worker_init_fn': _seed_worker,
    } if use_cuda else {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_dir, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_dir,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set, batch_size=BATCH_SIZE,
                                   shuffle=True, **params)
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=TEST_BATCH_SIZE,
                                        shuffle=False, **params)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                     factor=0.5, patience=1)

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('#### Model summary below ####\n {}\n'.format(str(model)))
    logger.info('===> Model total # parameter: {}\n'.format(model_params))

    ###########################################################
    # Training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and max patience
    total_train_time = []

    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)

        # Save
        if SELECT_BEST == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # SELECT_BEST == 'val' (validated above)
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })

        # Early stopping
        if is_best:
            early_stopping = 0
        else:
            early_stopping += 1

        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, elapsed time: {}".format(
            epoch, args.epochs, end_epoch_timer - epoch_timer))
        total_train_time.append(end_epoch_timer - epoch_timer)
        if early_stopping == max_patience:
            break

    if total_train_time:  # avoid averaging an empty list when epochs == 0
        logger.info("#### Avg. training+validation time per epoch: {}".format(
            np.average(total_train_time)))

    ###########################################################
    # Prediction
    if args.eval_dir and args.eval_utt2label:
        logger.info('===> loading eval dataset')
        eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
        eval_loader = data.DataLoader(eval_set, batch_size=TEST_BATCH_SIZE,
                                      shuffle=False, **params)

        logger.info('===> loading best model for prediction')
        checkpoint = torch.load(
            os.path.join(args.logging_dir, run_name + '-model_best.pth'),
            map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

        t_start_eval = timer()
        prediction(args, model, device, eval_loader, args.eval_utt2label,
                   rnn)
        t_end_eval = timer()
        logger.info("#### Total prediction time: {}".format(
            t_end_eval - t_start_eval))

    ###########################################################
    # Taken outside the eval branch: it used to be assigned only when an
    # eval set was given, so the final log line raised NameError otherwise.
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))