Esempio n. 1
0
def _seed_worker(worker_id, base_seed):
    """Seed NumPy's global RNG inside a ``DataLoader`` worker.

    Kept at module level (not a lambda or closure) so it remains picklable
    when worker processes are started with the ``spawn`` method.

    Args:
        worker_id: worker index that ``DataLoader`` passes to
            ``worker_init_fn``.
        base_seed: run-level seed; offset by ``worker_id`` so workers do not
            all draw the same NumPy stream.
    """
    np.random.seed(base_seed + worker_id)


def main():
    """Decompose a pretrained attention ResNet and fine-tune the result.

    Parses the command line, builds ``AttenResNet4`` (or
    ``AttenResNet4DeformAll`` when ``--seg`` is given), passes it through
    ``load_model`` and ``apply_net_decomp``, then fine-tunes with Adam and
    ``ReduceLROnPlateau``.  Training stops after ``--epochs`` epochs or once
    the selection metric has not improved for ``max_patience`` consecutive
    epochs.  A snapshot is written every epoch, and the weights of the best
    epoch are restored before the final size / parameter report.

    Relies on module-level configuration that is not defined here:
    ``atten_activation``, ``atten_channel``, ``M``, ``batch_size``,
    ``test_batch_size``, ``select_best`` and ``rnn``.

    Raises:
        ValueError: if ``select_best`` is neither ``'eer'`` nor ``'val'``.
    """
    # Local imports: only this entry point needs them.
    import copy
    from functools import partial

    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-dir', required=True, help='train feature dir')
    parser.add_argument('--train-utt2label',
                        required=True,
                        help='train utt2label')
    parser.add_argument('--validation-dir',
                        required=True,
                        help='dev feature dir')
    parser.add_argument('--validation-utt2label',
                        required=True,
                        help='dev utt2label')
    parser.add_argument('--model-path', help='path to the pretrained model')
    parser.add_argument('--logging-dir',
                        required=True,
                        help='model save directory')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    parser.add_argument('--rank',
                        type=int,
                        required=True,
                        help='rank value for decomposition: 2 or 4')
    parser.add_argument('--decomp-type',
                        default='tucker',
                        help='decomposition type: tucker cp')
    args = parser.parse_args()

    # Fail before any expensive work: an unknown criterion would otherwise
    # only surface as an unbound ``is_best`` after the first full epoch.
    if select_best not in ('eer', 'val'):
        raise ValueError(
            "select_best must be 'eer' or 'val', got {!r}".format(select_best))

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = "fine_decomp-AFN4-" + str(M) + "-orig-rank_" + str(
            args.rank) + '-' + args.decomp_type  # noqa
    else:
        model = AttenResNet4DeformAll(atten_activation,
                                      atten_channel,
                                      size1=(257, args.seg_win))  # noqa
        run_name = "fine_decomp-AFN4De-" + str(
            args.seg_win) + "-" + args.seg + "-rank_" + str(
                args.rank) + '-' + args.decomp_type  # noqa
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda" if use_cuda else "cpu")
    model = load_model(model, args.model_path)

    # perform decomposition
    model = apply_net_decomp(model=model,
                             logger=logger,
                             rank=args.rank,
                             decomp_type=args.decomp_type)
    model.to(device)
    ##############################################################
    # Loading the dataset and fine tune the compressed model
    if use_cuda:
        # This dict used to be built with ``np.random.seed(args.seed)``
        # evaluated eagerly, which handed ``None`` to the DataLoader as the
        # hook but also re-seeded NumPy right here.  Keep that re-seed so
        # runs stay comparable, and pass a real callable as the hook.
        np.random.seed(args.seed)
        params = {
            'num_workers': 0,
            'pin_memory': False,
            'worker_init_fn': partial(_seed_worker, base_seed=args.seed),
        }
    else:
        params = {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_dir, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_dir,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   **params)  # set shuffle to True
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=test_batch_size,
                                        shuffle=False,
                                        **params)  # set shuffle to False

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=1)  # noqa

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('#### Decomposed model summary below####\n {}\n'.format(
        str(model)))
    logger.info('===> Model total parameter: {}\n'.format(model_params))
    ###########################################################
    # Start training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and maximum patience
    best_state = None  # weights of the best epoch seen so far

    total_train_time = []
    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)
        # Save
        if select_best == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # 'val' -- validated above
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })
        # Early stopping
        if is_best:
            early_stopping = 0
            # state_dict() exposes the live tensors, so take a real copy;
            # otherwise later epochs would overwrite the "best" weights.
            best_state = copy.deepcopy(model.state_dict())
        else:
            early_stopping += 1
        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, elapsed time: {}".format(
            epoch, args.epochs, end_epoch_timer - epoch_timer))  # noqa
        total_train_time.append(end_epoch_timer - epoch_timer)
        if early_stopping >= max_patience:
            break
    if total_train_time:  # np.average([]) would warn and log nan
        logger.info("#### Avg. training+validation time per epoch: {}".format(
            np.average(total_train_time)))  # noqa
    ###########################################################
    # Report on the best epoch rather than whatever the last epoch left.
    if best_state is not None:
        model.load_state_dict(best_state)
    best_model = model
    logger.info("#### fine-tuned decomp model size (MB): {}".format(
        get_size_of_model(best_model)))
    model_params = sum(p.numel() for p in best_model.parameters()
                       if p.requires_grad)
    logger.info(
        '#### non-zero params after fine-tuning: {}'.format(model_params))
    logger.info(
        "################## Done fine-tuning decomp model ######################"
    )
Esempio n. 2
0
def main():
    """Run ``retrieve_weight`` on the eval set for each saved checkpoint.

    Parses the command line, builds the eval ``DataLoader`` and, for every
    path in the module-level ``models`` sequence, loads its ``state_dict``
    into the module-level ``model`` and calls ``retrieve_weight``.

    Relies on module-level names that are not defined here: ``run_name``,
    ``model``, ``models`` and ``rnn``.
    """
    import sys  # local import: only needed for the print threshold below

    ##############################################################
    ## Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--eval-scp', help='kaldi eval scp file')
    parser.add_argument('--eval-utt2label', help='train utt2label')
    parser.add_argument('--model-path', help='trained model')
    parser.add_argument('--logging-dir',
                        required=True,
                        help='model save directory')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--plot-dir', help='directory to save plots')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print('use_cuda is', use_cuda)

    # Global timer
    global_timer = timer()

    # Setup logs
    logger = setup_logs(args.logging_dir, run_name)

    # Setting random seeds for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    ##############################################################
    ## Loading the dataset
    params = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}

    logger.info('===> loading eval dataset')
    eval_set = SpoofDataset(args.eval_scp, args.eval_utt2label)
    eval_loader = data.DataLoader(eval_set,
                                  batch_size=args.test_batch_size,
                                  shuffle=False,
                                  **params)  # set shuffle to False
    ################### for multiple models #####################
    # Print arrays untruncated.  ``threshold=np.nan`` is rejected by current
    # NumPy with a ValueError; ``sys.maxsize`` is the supported spelling.
    np.set_printoptions(threshold=sys.maxsize)
    for model_i in models:
        logger.info('===> loading {} for prediction'.format(model_i))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(model_i, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
        print('model params is', model_params)

        retrieve_weight(args, model, device, eval_loader, args.eval_scp,
                        args.eval_utt2label, args.plot_dir, rnn)
    logger.info("===> Final predictions done. Here is a snippet")
    ###########################################################
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))
def main():
    """Quantize a trained attention ResNet and time prediction on the eval set.

    Parses the command line, builds ``AttenResNet4`` (or
    ``AttenResNet4DeformAll`` when ``--seg`` is given), loads the checkpoint
    at ``--model-path``, runs ``apply_net_quant`` with the training loader as
    ``calibration_loader`` and finally calls ``prediction`` on the eval
    loader.  Everything runs on the CPU device.

    Relies on module-level names that are not defined here:
    ``atten_activation``, ``atten_channel``, ``M``, ``BATCH_SIZE``,
    ``TEST_BATCH_SIZE`` and ``rnn``.

    Raises:
        ValueError: if ``--seg`` is given and ``--seg-win`` is not one of
            64, 128, 256 or 512.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--train-dir',
                        help='train feature dir')
    parser.add_argument('--train-utt2label',
                        help='train utt2label')
    parser.add_argument('--eval-dir',
                        help='eval feature dir')
    # The next two values are concatenated into log messages and used to load
    # data / weights, so omitting them used to end in a bare TypeError.
    # Marking them required turns that into a normal argparse usage error.
    parser.add_argument('--eval-utt2label', required=True,
                        help='eval utt2label')
    parser.add_argument('--model-path', required=True,
                        help='path to the pretrained model')
    # NOTE(review): --test-batch-size is parsed but the loaders below use the
    # module-level BATCH_SIZE / TEST_BATCH_SIZE -- confirm which is intended.
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--seg', default=None,
                        help='seg method')
    parser.add_argument('--seg-win', type=int,
                        help='segmented window size')
    parser.add_argument('--quant-method', required=True,
                        help='quantization method: dynamic static')
    args = parser.parse_args()

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M), static_quant=True)
        run_name = "quant_pred-AFN4-1091-orig" + time.strftime("-%Y_%m_%d")
    else:
        if args.seg_win not in (64, 128, 256, 512):
            raise ValueError("Invalid segment window! Must be 64, 128, 256, or 512")
        model = AttenResNet4DeformAll(atten_activation, atten_channel, size1=(257, args.seg_win))  # noqa
        run_name = "quant_pred-AFN4De-" + str(args.seg_win) + "-" + args.seg + time.strftime("-%Y_%m_%d")  # noqa
    logger = setup_logs(args.logging_dir, run_name)

    logger.info("use_cuda is False. Only runnable on CPU!")

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device('cpu')
    model.to(device)

    ##############################################################
    # Loading the dataset & the best model
    params = {}

    logger.info('===> loading eval dataset: ' + args.eval_utt2label)
    # The training set is only handed to apply_net_quant as calibration data.
    train_set = SpoofDataset(args.train_dir, args.train_utt2label)
    train_loader = data.DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        **params
    )  # set shuffle to True
    eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
    eval_loader = data.DataLoader(
        eval_set,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        **params
    )  # set shuffle to False
    logger.info('===> loading best model for prediction: ' + args.model_path)
    # (A single-argument os.path.join was a no-op, so pass the path directly.)
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    ##############################################################
    # apply network quantization & prediction
    model = apply_net_quant(
        model=model,
        logger=logger,
        quant_method=args.quant_method,
        calibration_loader=train_loader
    )

    t_start_eval = timer()
    eval_loss, eval_eer = prediction(args, model, device, eval_loader, args.eval_utt2label, rnn)  # noqa
    t_end_eval = timer()
    logger.info("#### Total prediction time: {}".format(t_end_eval - t_start_eval))
    ###########################################################
    logger.info("################## Success #########################\n\n")
Esempio n. 4
0
def _seed_worker(worker_id, base_seed):
    """Seed NumPy's global RNG inside a ``DataLoader`` worker.

    Kept at module level (not a lambda or closure) so it remains picklable
    when worker processes are started with the ``spawn`` method.

    Args:
        worker_id: worker index that ``DataLoader`` passes to
            ``worker_init_fn``.
        base_seed: run-level seed; offset by ``worker_id`` so workers do not
            all draw the same NumPy stream.
    """
    np.random.seed(base_seed + worker_id)


def main():
    """Time prediction of a fine-tuned, decomposed attention ResNet.

    Parses the command line, builds ``AttenResNet4`` (or
    ``AttenResNet4DeformAll`` when ``--seg`` is given), applies
    ``apply_net_decomp`` on the CPU so the layer structure matches the
    checkpoint, loads the ``state_dict`` from ``--model-path`` and calls
    ``prediction`` on the eval loader.

    Relies on module-level names that are not defined here:
    ``atten_activation``, ``atten_channel``, ``M``, ``TEST_BATCH_SIZE`` and
    ``rnn``.

    Raises:
        ValueError: if ``--seg`` is given and ``--seg-win`` is not one of
            64, 128, 256 or 512.
    """
    from functools import partial  # local import: only used for the hook

    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--eval-dir',
                        help='eval feature dir')
    # The next two values are concatenated into log messages and used to load
    # data / weights, so omitting them used to end in a bare TypeError.
    # Marking them required turns that into a normal argparse usage error.
    parser.add_argument('--eval-utt2label', required=True,
                        help='train utt2label')
    parser.add_argument('--model-path', required=True,
                        help='path to the pretrained model')
    # NOTE(review): --test-batch-size is parsed but the loader below uses the
    # module-level TEST_BATCH_SIZE -- confirm which is intended.
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--logging-dir', required=True,
                        help='model save directory')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--seg', default=None,
                        help='seg method')
    parser.add_argument('--seg-win', type=int,
                        help='segmented window size')
    parser.add_argument('--decomp-rank', type=int,
                        help='rank value for decomposition')
    parser.add_argument('--decomp-type', default='tucker',
                        help='decomposition type')

    args = parser.parse_args()

    torch.cuda.empty_cache()

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = "decomp_pred-AFN4-1091-orig" + time.strftime("-%Y_%m_%d")
    else:
        if args.seg_win not in (64, 128, 256, 512):
            raise ValueError("Invalid segment window! Must be 64, 128, 256, or 512")
        model = AttenResNet4DeformAll(atten_activation, atten_channel, size1=(257, args.seg_win))  # noqa
        run_name = "decomp_pred-AFN4De-" + str(args.seg_win) + "-" + args.seg + time.strftime("-%Y_%m_%d")  # noqa
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    ##############################################################
    # Loading the dataset
    # worker_init_fn must be a callable; writing ``np.random.seed(args.seed)``
    # here would run the seeder immediately and hand ``None`` to the loader.
    params = {'num_workers': 0,
              'pin_memory': False,
              'worker_init_fn': partial(_seed_worker, base_seed=args.seed),
              } if use_cuda else {}

    logger.info('===> loading eval dataset: ' + args.eval_utt2label)
    eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
    eval_loader = data.DataLoader(
        eval_set,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        **params
    )  # set shuffle to False

    ##############################################################
    # apply network decomposition and load fine-tuned weights
    model.to('cpu')
    model = apply_net_decomp(
        model=model,
        logger=logger,
        rank=args.decomp_rank,
        decomp_type=args.decomp_type
    )
    logger.info('===> loading fine-tuned model for prediction: ' + args.model_path)
    # (A single-argument os.path.join was a no-op, so pass the path directly.)
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    t_start_eval = timer()
    eval_loss, eval_eer = prediction(args, model, device, eval_loader, args.eval_utt2label, rnn)  # noqa
    t_end_eval = timer()
    logger.info("#### Total prediction time: {}".format(t_end_eval - t_start_eval))
    ###########################################################
    logger.info("################## Success #########################\n\n")
def main():
    """Extract log-spectrogram features for every utterance in a protocol file.

    For each non-blank line of ``--label-file`` (``<wave file> <label> ...``)
    this writes ``<utt id> <0|1>`` to ``<label-file>.utt2label``, computes the
    log-spectrogram with ``get_logspec``, optionally segments / expands it,
    and saves the result as ``<feat-dir>/<utt id>.pt``.  The utterance id is
    the wave file name without its extension.

    Relies on module-level names that are not defined here: ``DATA_NAME``,
    ``DEFAULT_SEG``, ``TAIL_SEG``, ``get_logspec``, ``expand_logspec``,
    ``segment_logspec`` and ``setup_logs``.

    Raises:
        ValueError: if ``--segment`` is set without ``--seg-method`` and
            ``--seg-win``, if a protocol line has fewer than two fields, or
            if a label is neither ``genuine`` nor ``spoof``.
    """
    parser = argparse.ArgumentParser(
        description='Feature (log-spec) Extraction')
    parser.add_argument('--data-dir',
                        required=True,
                        help='data directory contains wave files')
    parser.add_argument('--label-file',
                        required=True,
                        help='protocol file that contains utt2label mapping')
    parser.add_argument('--feat-dir',
                        required=True,
                        help='feature saving directory')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA for feature extraction')
    parser.add_argument('--logging-dir',
                        required=True,
                        help='log save directory')
    parser.add_argument('--segment',
                        action='store_true',
                        default=False,
                        help='whether to segment the logsepc by energy')
    parser.add_argument(
        '--seg-win',
        type=int,
        help='the window size to be used for segment: 64, 128..')
    parser.add_argument('--seg-method',
                        help='the method to be used for segment: h, l, hl, lh')
    args = parser.parse_args()

    os.makedirs(args.logging_dir, exist_ok=True)
    os.makedirs(args.feat_dir, exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.segment and (args.seg_method is None or args.seg_win is None):
        raise ValueError("segment method or win_size is missing")

    # Setup logs
    basename = DATA_NAME + os.path.basename(args.data_dir)
    run_name = "logspec-" + basename + time.strftime("-%Y-%m-%d")
    # NOTE(review): this path is relative to the current working directory,
    # while the logger is set up with args.logging_dir -- confirm where
    # setup_logs actually writes the file.
    if os.path.exists(run_name + ".log"):
        os.remove(run_name + ".log")
    logger = setup_logs(args.logging_dir, run_name)

    logger.info("===> Start logspec extraction for dataset: " + args.data_dir)

    # global timer starts
    global_start = timer()

    utt2label_path = os.path.join(
        os.path.dirname(args.label_file),
        os.path.basename(args.label_file) + '.utt2label')

    # Context managers close both files even when a line is rejected or
    # feature extraction raises.  The input is opened first so that a missing
    # protocol file does not truncate an existing utt2label file.
    with open(args.label_file, 'r') as f_label, \
            open(utt2label_path, 'w') as f_utt2label:
        for line in f_label:
            # split() (no argument) tolerates tabs and repeated spaces.
            item = line.split()
            if not item:
                continue  # blank line
            if len(item) < 2:
                raise ValueError("Malformed protocol line: " + line.rstrip())
            if item[1] == 'genuine':
                label = 1
            elif item[1] == 'spoof':
                label = 0
            else:
                raise ValueError("Invalid label: " + item[1])
            # Strip the extension whatever its length (same result as the
            # old [:-4] slice for '.wav').
            utt_id = os.path.splitext(item[0])[0]
            f_utt2label.write(utt_id + ' ' + str(label) + '\n')

            audio_path = os.path.join(args.data_dir, item[0])

            t_start = timer()
            logspec = get_logspec(audio_path, device)
            if args.segment:
                if args.seg_method == DEFAULT_SEG:
                    feat = expand_logspec(logspec, M=args.seg_win)
                elif args.seg_method == TAIL_SEG:
                    # Flip, expand, flip back.
                    feat = torch.flipud(
                        expand_logspec(torch.flipud(logspec), M=args.seg_win))
                else:
                    feat = segment_logspec(logspec, args.seg_win,
                                           args.seg_method)
            else:
                feat = expand_logspec(logspec)
            t_end = timer()
            logger.info(item[0] + "\tfeature extraction time: %s" %
                        (t_end - t_start))

            f_feat_path = os.path.join(args.feat_dir, utt_id + '.pt')
            torch.save(feat, f_feat_path)

    global_end = timer()
    logger.info("#### Done logspec extraction for dataset: " + args.data_dir +
                "####")
    logger.info("Total elapsed time: %s" % (global_end - global_start))
Esempio n. 6
0
def _seed_worker(worker_id, base_seed):
    """Seed NumPy's global RNG inside a ``DataLoader`` worker.

    Kept at module level (not a lambda or closure) so it remains picklable
    when worker processes are started with the ``spawn`` method.

    Args:
        worker_id: worker index that ``DataLoader`` passes to
            ``worker_init_fn``.
        base_seed: run-level seed; offset by ``worker_id`` so workers do not
            all draw the same NumPy stream.
    """
    np.random.seed(base_seed + worker_id)


def main():
    """Train the model, keep the best snapshot and predict on the eval set.

    Parses the command line, builds train / validation / eval loaders, and
    trains the module-level ``model`` with Adam and ``ReduceLROnPlateau``.
    Training stops after ``--epochs`` epochs or once the selection metric has
    not improved for ``max_patience`` consecutive epochs.  A snapshot is
    written every epoch; afterwards ``<run_name>-model_best.pth`` is loaded
    from ``--logging-dir`` and ``prediction`` is run on the eval loader.

    Relies on module-level names that are not defined here: ``run_name``,
    ``model``, ``batch_size``, ``test_batch_size``, ``select_best`` and
    ``rnn``.

    Raises:
        ValueError: if ``select_best`` is neither ``'eer'`` nor ``'val'``.
    """
    from functools import partial  # local import: only used for the hook

    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-scp',
                        required=True,
                        help='kaldi train scp file')
    parser.add_argument('--train-utt2label',
                        required=True,
                        help='train utt2label')
    parser.add_argument('--validation-scp',
                        required=True,
                        help='kaldi dev scp file')
    parser.add_argument('--validation-utt2label',
                        required=True,
                        help='dev utt2label')
    parser.add_argument('--eval-scp', help='kaldi eval scp file')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--logging-dir',
                        required=True,
                        help='model save directory')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval',
                        type=int,
                        default=10,
                        metavar='N',
                        help='how many batches to wait before logging '
                        'training status')
    parser.add_argument('--hidden-dim',
                        type=int,
                        default=100,
                        help='number of neurones in the hidden dimension')
    parser.add_argument('--plot-wd', help='training plot directory')
    args = parser.parse_args()

    # Fail before any expensive work: an unknown criterion would otherwise
    # only surface as an unbound ``is_best`` after the first full epoch.
    if select_best not in ('eer', 'val'):
        raise ValueError(
            "select_best must be 'eer' or 'val', got {!r}".format(select_best))

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print('use_cuda is', use_cuda)

    # Global timer
    global_timer = timer()

    # Setup logs
    logger = setup_logs(args.logging_dir, run_name)

    # Setting random seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    device = torch.device("cuda:0" if use_cuda else "cpu")
    model.to(device)
    ##############################################################
    # Loading the dataset
    # worker_init_fn must be a callable; writing ``np.random.seed(args.seed)``
    # here would run the seeder immediately and hand ``None`` to the loader,
    # leaving the 16 worker processes without any explicit NumPy seeding.
    params = {
        'num_workers': 16,
        'pin_memory': True,
        'worker_init_fn': partial(_seed_worker, base_seed=args.seed),
    } if use_cuda else {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_scp, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_scp,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   **params)  # set shuffle to True
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=test_batch_size,
                                        shuffle=False,
                                        **params)  # set shuffle to False

    logger.info('===> loading eval dataset')
    eval_set = SpoofDataset(args.eval_scp, args.eval_utt2label)
    eval_loader = data.DataLoader(eval_set,
                                  batch_size=test_batch_size,
                                  shuffle=False,
                                  **params)  # set shuffle to False

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=1)

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('### Model summary below###\n {}\n'.format(str(model)))
    logger.info('===> Model total parameter: {}\n'.format(model_params))
    ###########################################################
    # Start training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and maximum patience
    print(run_name)
    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_scp,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)
        # Save
        if select_best == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # 'val' -- validated above
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })
        # Early stopping
        if is_best:
            early_stopping = 0
        else:
            early_stopping += 1
        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, "
                    "elapsed time: {}".format(epoch, args.epochs,
                                              end_epoch_timer - epoch_timer))
        if early_stopping >= max_patience:
            break
    ###########################################################
    # Prediction
    logger.info('===> loading best model for prediction')
    # map_location keeps the load independent of the device the snapshot
    # was written from.
    checkpoint = torch.load(
        os.path.join(args.logging_dir, run_name + '-model_best.pth'),
        map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    eval_loss, eval_eer = prediction(args, model, device, eval_loader,
                                     args.eval_scp, args.eval_utt2label, rnn)
    ###########################################################
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))
def main():
    """Train, validate and optionally evaluate the AFN model from the CLI.

    Steps performed:
      1. Parse command-line arguments.
      2. Seed the RNGs, then build ``AttenResNet4`` (no ``--seg``) or
         ``AttenResNet4DeformAll`` (``--seg`` given) and set up logging.
      3. Train for up to ``--epochs`` epochs with Adam and a
         ``ReduceLROnPlateau`` scheduler stepped on the validation loss,
         snapshotting every epoch and stopping early after ``max_patience``
         consecutive epochs without improvement.
      4. If both ``--eval-dir`` and ``--eval-utt2label`` are given, reload the
         best snapshot and run prediction on the eval set.

    Relies on names expected to be defined elsewhere in this module
    (``atten_activation``, ``atten_channel``, ``M``, ``BATCH_SIZE``,
    ``TEST_BATCH_SIZE``, ``SELECT_BEST``, ``rnn``, and the helper functions).

    Raises:
        ValueError: if ``SELECT_BEST`` is neither ``'eer'`` nor ``'val'``.
    """
    ##############################################################
    # Settings
    parser = argparse.ArgumentParser(description='Model AFN')
    parser.add_argument('--train-dir', required=True, help='train feature dir')
    parser.add_argument('--train-utt2label',
                        required=True,
                        help='train utt2label')
    parser.add_argument('--validation-dir',
                        required=True,
                        help='dev feature dir')
    parser.add_argument('--validation-utt2label',
                        required=True,
                        help='dev utt2label')
    parser.add_argument('--eval-dir', help='eval feature dir')
    parser.add_argument('--eval-utt2label', help='eval utt2label')
    parser.add_argument('--logging-dir',
                        required=True,
                        help='model save directory')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--seg', default=None, help='seg method')
    parser.add_argument('--seg-win', type=int, help='segmented window size')
    args = parser.parse_args()

    # Fail fast on a bad selection criterion instead of hitting an unbound
    # ``is_best`` only after a full training epoch has already run.
    if SELECT_BEST not in ('eer', 'val'):
        raise ValueError(
            "SELECT_BEST must be 'eer' or 'val', got {!r}".format(SELECT_BEST))

    torch.cuda.empty_cache()

    # Setting random seeds for reproducibility. This must happen before the
    # model is constructed so that construction runs under the seeded RNGs.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # CUDA determinism

    # Init model & Setup logs
    if args.seg is None:
        model = AttenResNet4(atten_activation, atten_channel, size1=(257, M))
        run_name = "AFN4" + time.strftime("-%Y_%m_%d-%H_%M_%S-") + str(
            M) + "-orig"
    else:
        model = AttenResNet4DeformAll(atten_activation,
                                      atten_channel,
                                      size1=(257, args.seg_win))  # noqa
        run_name = "AFN4De" + time.strftime("-%Y_%m_%d-%H_%M_%S-") + str(
            args.seg_win) + "-" + args.seg  # noqa
    logger = setup_logs(args.logging_dir, run_name)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info("use_cuda is {}".format(use_cuda))

    # Global timer
    global_timer = timer()

    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    ##############################################################
    # Loading the dataset

    def _seed_worker(worker_id):
        # worker_init_fn must be a callable taking the worker id; the original
        # passed the *result* of np.random.seed(...) (i.e. None) instead.
        np.random.seed(args.seed + worker_id)

    params = {
        'num_workers': 0,
        'pin_memory': False,
        'worker_init_fn': _seed_worker
    } if use_cuda else {}

    logger.info('===> loading train and dev dataset')
    training_set = SpoofDataset(args.train_dir, args.train_utt2label)
    validation_set = SpoofDataset(args.validation_dir,
                                  args.validation_utt2label)
    train_loader = data.DataLoader(training_set,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True,
                                   **params)  # set shuffle to True
    validation_loader = data.DataLoader(validation_set,
                                        batch_size=TEST_BATCH_SIZE,
                                        shuffle=False,
                                        **params)  # set shuffle to False

    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # Halve the learning rate when the validation loss stops decreasing.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=1)  # noqa

    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('#### Model summary below ####\n {}\n'.format(str(model)))
    logger.info('===> Model total # parameter: {}\n'.format(model_params))
    ###########################################################
    # Training
    best_eer, best_loss = np.inf, np.inf
    early_stopping, max_patience = 0, 5  # early stopping and maximum patience

    total_train_time = []
    for epoch in range(1, args.epochs + 1):
        epoch_timer = timer()

        # Train and validate
        train(args, model, device, train_loader, optimizer, epoch, rnn)
        val_loss, eer = validation(args, model, device, validation_loader,
                                   args.validation_utt2label, rnn)
        scheduler.step(val_loss)
        # Save
        if SELECT_BEST == 'eer':
            is_best = eer < best_eer
            best_eer = min(eer, best_eer)
        else:  # 'val' (validated above)
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
        snapshot(
            args.logging_dir, run_name, is_best, {
                'epoch': epoch + 1,
                'best_eer': best_eer,
                'state_dict': model.state_dict(),
                'validation_loss': val_loss,
                'optimizer': optimizer.state_dict(),
            })
        # Early stopping: count consecutive epochs without improvement.
        if is_best:
            early_stopping = 0
        else:
            early_stopping += 1
        end_epoch_timer = timer()
        logger.info("#### End epoch {}/{}, elapsed time: {}".format(
            epoch, args.epochs, end_epoch_timer - epoch_timer))  # noqa
        total_train_time.append(end_epoch_timer - epoch_timer)
        if early_stopping == max_patience:
            break
    logger.info("#### Avg. training+validation time per epoch: {}".format(
        np.average(total_train_time)))  # noqa
    ###########################################################
    # Prediction
    if args.eval_dir and args.eval_utt2label:
        logger.info('===> loading eval dataset')
        eval_set = SpoofDataset(args.eval_dir, args.eval_utt2label)
        eval_loader = data.DataLoader(eval_set,
                                      batch_size=TEST_BATCH_SIZE,
                                      shuffle=False,
                                      **params)  # set shuffle to False

        logger.info('===> loading best model for prediction')
        # map_location keeps the load independent of the device the
        # checkpoint tensors were saved from.
        checkpoint = torch.load(
            os.path.join(args.logging_dir, run_name + '-model_best.pth'),
            map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        t_start_eval = timer()
        prediction(args, model, device, eval_loader, args.eval_utt2label,
                   rnn)  # noqa
        t_end_eval = timer()
        logger.info(
            "#### Total prediction time: {}".format(t_end_eval -
                                                    t_start_eval))  # noqa
    elif args.eval_dir or args.eval_utt2label:
        # Only one of the two eval arguments was supplied; say so rather than
        # silently skipping evaluation.
        logger.warning('Skipping prediction: both --eval-dir and '
                       '--eval-utt2label are required')
    ###########################################################
    # Taken unconditionally so the final log line works whether or not the
    # prediction stage ran (it used to raise NameError without an eval set).
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))