示例#1
0
# Checkpoint loaded when --model_path is not supplied (the path this function
# previously hard-coded unconditionally).
_DEFAULT_CHECKPOINT = '/backup1/datas/models/students/epoch_19.pth'


def _build_arg_parser():
    """Return the argparse parser holding the ASVspoof2019 train/eval options."""
    parser = argparse.ArgumentParser('UCLANESL ASVSpoof2019  model')
    parser.add_argument('--eval',
                        action='store_true',
                        default=False,
                        help='eval mode')  # command-line flag
    parser.add_argument('--model_path',
                        type=str,
                        default=None,
                        help='Model checkpoint')
    parser.add_argument('--eval_output',
                        type=str,
                        default=None,
                        help='Path to save the evaluation result')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_epochs', type=int, default=20)  # number of epochs
    parser.add_argument('--lr', type=float, default=0.00005)
    parser.add_argument('--comment',
                        type=str,
                        default=None,
                        help='Comment to describe the saved mdoel')
    parser.add_argument('--track', type=str, default='logical')
    parser.add_argument('--features', type=str, default='spect')
    parser.add_argument('--is_eval', action='store_true', default=False)
    parser.add_argument('--eval_part', type=int, default=0)
    return parser


def channel(filename, argv=None):
    """Evaluate a pretrained ASVspoof2019 model on the data selected by ``filename``.

    Options (track, feature type, batch size, checkpoint, ...) come from the
    command line, parsed with the same flags as the training script.

    Args:
        filename: Forwarded to ``data_utils.ASVDataset`` as ``file=``;
            presumably the audio file / list to score -- TODO confirm against
            ASVDataset.
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (default) parses ``sys.argv[1:]``,
            exactly as before.

    Returns:
        Whatever ``evaluate_accuracy(dev_loader, model, device)`` returns.

    Raises:
        AssertionError: if ``--features`` or ``--track`` is unsupported.
    """
    parser = _build_arg_parser()
    os.makedirs('models', exist_ok=True)
    args = parser.parse_args(argv)  # collect all add_argument options into args

    track = args.track
    assert args.features in ['mfcc', 'spect', 'cqcc'], 'Not supported feature'
    model_tag = 'model_{}_{}_{}_{}_{}'.format(track, args.features,
                                              args.num_epochs, args.batch_size,
                                              args.lr)
    if args.comment:
        model_tag = model_tag + '_{}'.format(args.comment)
    model_save_path = os.path.join('models', model_tag)
    assert track in ['logical', 'physical'], 'Invalid track given'
    is_logical = (track == 'logical')
    os.makedirs(model_save_path, exist_ok=True)

    if args.features == 'mfcc':
        feature_fn = compute_mfcc_feats
        model_cls = MFCCModel
    elif args.features == 'spect':
        feature_fn = get_log_spectrum
        model_cls = SpectrogramModel
    elif args.features == 'cqcc':
        feature_fn = None  # cqcc feature is extracted in Matlab script
        model_cls = CQCCModel

    # Use a new name: assigning to `transforms` inside this function would make
    # it a local variable, and `transforms.Compose` would then raise
    # UnboundLocalError before the module-level `transforms` is ever consulted.
    # Steps: pad -> librosa normalization -> feature extraction -> Tensor.
    transform = transforms.Compose([
        lambda x: pad(x),
        lambda x: librosa.util.normalize(x),
        lambda x: feature_fn(x),  # compute features, e.g. MFCC
        lambda x: Tensor(x)
    ])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dev_set = data_utils.ASVDataset(is_train=False,
                                    is_logical=is_logical,
                                    transform=transform,
                                    feature_name=args.features,
                                    is_eval=True,
                                    eval_part=args.eval_part,
                                    file=filename)
    # Evaluation only: no need to shuffle, and a fixed order keeps any
    # per-sample output aligned with the dataset order.
    dev_loader = DataLoader(dev_set, batch_size=args.batch_size, shuffle=False)
    model = model_cls().to(device)

    # Honour --model_path; fall back to the previously hard-coded checkpoint.
    checkpoint_path = args.model_path or _DEFAULT_CHECKPOINT
    # map_location lets a checkpoint saved on GPU load on the CPU fallback.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    pred = evaluate_accuracy(dev_loader, model, device)
    return pred
示例#2
0
        # (Fragment: tail of the 'spect' branch of the feature-type dispatch.)
        feature_fn = get_log_spectrum
        model_cls = SpectrogramModel
    elif args.features == 'cqcc':
        feature_fn = None  # cqcc feature is extracted in Matlab script
        model_cls = CQCCModel

    # Preprocessing pipeline applied to each sample, in order:
    # pad -> librosa normalization -> feature extraction -> torch Tensor.
    # NOTE(review): rebinding the name `transforms` only works at module/script
    # scope; inside a function body it would raise UnboundLocalError.
    transforms = transforms.Compose([  # preprocessing: the four steps below
        lambda x: pad(x),  # lambdas wrap each step as a simple input -> output call
        lambda x: librosa.util.normalize(x),
        lambda x: feature_fn(x),  # compute features, e.g. MFCC
        lambda x: Tensor(x)
    ])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Dev/eval split, chosen by --is_eval and --eval_part.
    dev_set = data_utils.ASVDataset(is_train=False, is_logical=is_logical,
                                    transform=transforms,
                                    feature_name=args.features, is_eval=args.is_eval, eval_part=args.eval_part)
    # NOTE(review): shuffle=True on an evaluation loader; the eval path below
    # passes dev_set (not dev_loader) to produce_evaluation_file.
    dev_loader = DataLoader(dev_set, batch_size=args.batch_size, shuffle=True)
    model = model_cls().to(device)
    print(args)

    # Optionally load weights from --model_path.
    # NOTE(review): no map_location is given, so a GPU-saved checkpoint may fail
    # to load when device == 'cpu' -- confirm.
    if args.model_path:
        model.load_state_dict(torch.load(args.model_path))
        print('Model loaded : {}'.format(args.model_path))

    if args.eval:  # whether to enter the evaluation stage (off by default)
        # Evaluation requires both an output path and a checkpoint; write the
        # result via produce_evaluation_file, then exit the process.
        assert args.eval_output is not None, 'You must provide an output path'
        assert args.model_path is not None, 'You must provide model checkpoint'
        produce_evaluation_file(dev_set, model, device, args.eval_output)
        sys.exit(0)
示例#3
0
    print(args)

    
    # load data & run
    if args.eval: 
        # if evaluation
        # Load model parameters (weights) from --model_path, if given.
        # NOTE(review): the assert below makes --model_path mandatory here, but it
        # only runs after this load attempt; torch.load has no map_location.
        if args.model_path:
            model.load_state_dict(torch.load(args.model_path))
            print('Model loaded : {}'.format(args.model_path))
        assert args.eval_output is not None, 'You must provide an output path'
        assert args.model_path is not None, 'You must provide model checkpoint'
        # Load data - dev/eval split (chosen by --is_eval).
        # NOTE(review): unlike the other entry point, eval_part is not passed here.
        
        dev_set = data_utils.ASVDataset(is_train=False, is_logical=is_logical,
                                        transform=transforms,
                                        feature_name=args.features, is_eval=args.is_eval)
        # NOTE(review): dev_loader is not used in the visible eval path;
        # produce_evaluation_file receives dev_set directly.
        dev_loader = DataLoader(dev_set, batch_size=args.batch_size, shuffle=True)

        # run
        produce_evaluation_file(dev_set, model, device, args.eval_output)
        

    else:
        # if training
        # Load data - training split.
        train_set = data_utils.ASVDataset(is_train=True, is_logical=is_logical, 
                                        transform=transforms,
                                        feature_name=args.features)
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
        # Load data - dev split, used to monitor accuracy during training.