Example 1
def main(argv):

    args = parse_args(argv)

    if args.path_checkpoint is None:
        if args.file_extension == '.pt':
            feature_function = load_pt
        elif args.file_extension == '.npy':
            feature_function = load_npy
        else:
            # Without a checkpoint, only pre-computed .pt / .npy features can be loaded
            raise ValueError(
                f"Unsupported file extension {args.file_extension}: "
                "provide --path_checkpoint or use .pt / .npy features")
    else:
        state_dict = torch.load(args.path_checkpoint)
        feature_maker = load_cpc_features(state_dict)
        feature_maker.cuda()

        def feature_function(x):
            return build_feature_from_file(x, feature_maker)

    # Modes
    if args.mode == 'all':
        modes = ["within", "across"]
    else:
        modes = [args.mode]

    step_feature = 1 / args.feature_size

    # Get the list of sequences
    seq_list = find_all_files(args.path_data, args.file_extension)

    scores = ABX(feature_function,
                 args.path_item_file,
                 seq_list,
                 args.distance_mode,
                 step_feature,
                 modes,
                 cuda=args.cuda,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    out_dir = Path(args.path_checkpoint).parent if args.out is None \
        else Path(args.out)
    out_dir.mkdir(exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)
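
A minimal driver sketch for this entry point (an assumption: the surrounding script defines parse_args over the args.* fields read above; the script name in the comment is a placeholder, not part of the original source):

if __name__ == '__main__':
    import sys
    # Forward the command-line flags straight to main(); the flag spellings are
    # assumed to match the args.* attributes used above, e.g.
    #   python eval_abx.py --path_data data/ --path_item_file test.item \
    #       --file_extension .wav --feature_size 0.01 --mode all
    main(sys.argv[1:])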
Example 2
def per(args):

    # Load the model
    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict["model"])
    feature_maker.cuda()
    feature_maker.eval()
    hidden_gar = feature_maker.get_output_dim()

    # Get the model training configuration
    path_config = Path(args.path_checkpoint).parent / "args_training.json"
    with open(path_config, 'rb') as file:
        config_training = json.load(file)

    n_phones = get_n_phones(config_training["path_phone_converter"])
    phone_criterion = per_src.CTCPhoneCriterion(
        hidden_gar,
        n_phones,
        config_training["LSTM"],
        seqNorm=config_training["seqNorm"],
        dropout=config_training["dropout"],
        reduction=config_training["loss_reduction"])
    phone_criterion.load_state_dict(state_dict["classifier"])
    phone_criterion.cuda()
    downsamplingFactor = 160

    # dataset
    inSeqs = find_all_files(args.pathDB, args.file_extension)
    phoneLabels = parse_phone_labels(args.pathPhone)

    datasetVal = per_src.SingleSequenceDataset(args.pathDB,
                                               inSeqs,
                                               phoneLabels,
                                               inDim=1)
    valLoader = DataLoader(datasetVal,
                           batch_size=args.batchSize,
                           shuffle=False)

    per_step(valLoader, feature_maker, phone_criterion, downsamplingFactor)
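
A hedged usage sketch for per(): the attribute names below are exactly the fields the function reads, while every concrete value is a placeholder.

import argparse

# Illustrative namespace; all paths here are placeholders.
eval_args = argparse.Namespace(
    path_checkpoint='checkpoint.pt',          # checkpoint holding 'model' and 'classifier' entries
    pathDB='data/LibriSpeech/dev-clean',      # root directory of the audio files
    pathPhone='data/phone_labels.txt',        # annotations parsed by parse_phone_labels
    file_extension='.flac',
    batchSize=8,
)
per(eval_args)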
Example 3
def main(args):
    description = "Training and evaluation of a letter classifier on top of a pre-trained CPC model. "
    "Please specify at least one `path_wer` (to calculate WER) or `path_train` and `path_val` (for training)."

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument('--path_checkpoint', type=str)
    parser.add_argument('--path_train', default=None, type=str)
    parser.add_argument('--path_val', default=None, type=str)
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--downsampling_factor', type=int, default=160)

    parser.add_argument('--lr', type=float, default=2e-04)
    parser.add_argument('--output',
                        type=str,
                        default='out',
                        help="Output directory")
    parser.add_argument('--p_dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=32)

    parser.add_argument('--lm_weight', type=float, default=2.0)
    parser.add_argument('--path_wer',
                        help="For computing the WER on specific sequences",
                        action='append')
    parser.add_argument('--letters_path',
                        type=str,
                        default='WER_data/letters.lst')

    args = parser.parse_args(args=args)

    if not args.path_wer and not (args.path_train and args.path_val):
        print(
            'Please specify at least one `path_wer` (to calculate WER) or '
            '`path_train` and `path_val` (for training).'
        )
        return

    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    # creating models before reading the datasets
    with open(args.letters_path) as f:
        n_chars = len(f.readlines())

    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict)
    feature_maker.cuda()
    hidden = feature_maker.get_output_dim()

    letter_classifier = LetterClassifier(
        feature_maker,
        hidden,
        n_chars,
        p_dropout=args.p_dropout)

    criterion = CTCLetterCriterion(letter_classifier, n_chars)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion)

    # Checkpoint file where the model should be saved
    path_checkpoint = os.path.join(args.output, 'checkpoint.pt')

    if args.path_train and args.path_val:
        set_seed(args.seed)

        char_labels_val, n_chars, _ = parse_ctc_labels_from_root(
            args.path_val, letters_path=args.letters_path)
        print(f"Loading the validation dataset at {args.path_val}")
        dataset_val = SingleSequenceDataset(args.path_val, char_labels_val)
        val_loader = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=False)

        # train dataset
        char_labels_train, n_chars, _ = parse_ctc_labels_from_root(
            args.path_train, letters_path=args.letters_path)

        print(f"Loading the training dataset at {args.path_train}")
        dataset_train = SingleSequenceDataset(args.path_train,
                                              char_labels_train)
        train_loader = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True)

        # Optimizer
        g_params = list(criterion.parameters())
        optimizer = torch.optim.Adam(g_params, lr=args.lr)

        args_path = os.path.join(args.output, "args_training.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

        run(train_loader, val_loader, criterion, optimizer,
            args.downsampling_factor, args.n_epochs, path_checkpoint)

    if args.path_wer:
        args = get_eval_args(args)

        state_dict = torch.load(path_checkpoint)
        criterion.load_state_dict(state_dict)
        criterion = criterion.module
        criterion.eval()

        args_path = os.path.join(args.output, "args_validation.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

        for path_wer in args.path_wer:
            print(f"Loading the validation dataset at {path_wer}")

            char_labels_wer, _, (letter2index,
                                 index2letter) = parse_ctc_labels_from_root(
                                     path_wer, letters_path=args.letters_path)
            dataset_eval = SingleSequenceDataset(path_wer, char_labels_wer)
            eval_loader = DataLoader(dataset_eval,
                                     batch_size=args.batch_size,
                                     shuffle=False)

            wer = eval_wer(eval_loader, criterion, args.lm_weight,
                           index2letter)
            print(f'WER: {wer}')
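
Because this main() owns its own parser, a hedged way to call it programmatically is to pass an argv-style list; the flag names come from the parser above, and every path is a placeholder.

main([
    '--path_checkpoint', 'cpc_checkpoint.pt',
    '--path_train', 'data/train',
    '--path_val', 'data/val',
    '--path_wer', 'data/test-clean',
    '--output', 'out_letters',
    '--batch_size', '16',
])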
Example 4
def train(args):

    # Output Directory
    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

    nPhones = get_n_phones(args.path_phone_converter)
    phoneLabels = parse_phone_labels(args.pathPhone)
    inSeqs = find_all_files(args.pathDB, args.file_extension)

    # Model
    downsamplingFactor = 160
    state_dict = torch.load(args.pathCheckpoint)
    featureMaker = load_cpc_features(state_dict)
    hiddenGar = featureMaker.get_output_dim()
    featureMaker.cuda()
    featureMaker = torch.nn.DataParallel(featureMaker)

    # Criterion
    phoneCriterion = per_src.CTCPhoneCriterion(hiddenGar,
                                               nPhones,
                                               args.LSTM,
                                               seqNorm=args.seqNorm,
                                               dropout=args.dropout,
                                               reduction=args.loss_reduction)
    phoneCriterion.cuda()
    phoneCriterion = torch.nn.DataParallel(phoneCriterion)

    # Datasets
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filter_seq(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None:
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    else:
        seqVal = filter_seq(args.pathVal, inSeqs)
        print(len(seqVal), len(inSeqs), args.pathVal)

    if args.debug:
        seqVal = seqVal[:100]

    print(f"Loading the validation dataset at {args.pathDB}")
    datasetVal = per_src.SingleSequenceDataset(args.pathDB,
                                               seqVal,
                                               phoneLabels,
                                               inDim=args.in_dim)

    valLoader = DataLoader(datasetVal, batch_size=args.batchSize, shuffle=True)

    # Checkpoint file where the model should be saved
    pathCheckpoint = os.path.join(args.output, 'checkpoint.pt')

    featureMaker.optimize = True
    if args.freeze:
        featureMaker.eval()
        featureMaker.optimize = False
        for g in featureMaker.parameters():
            g.requires_grad = False

    if args.debug:
        print("debug")
        random.shuffle(seqTrain)
        # seqVal was already truncated in the debug branch above
        seqTrain = seqTrain[:1000]

    print(f"Loading the training dataset at {args.pathDB}")
    datasetTrain = per_src.SingleSequenceDataset(args.pathDB,
                                                 seqTrain,
                                                 phoneLabels,
                                                 inDim=args.in_dim)

    trainLoader = DataLoader(datasetTrain,
                             batch_size=args.batchSize,
                             shuffle=True)

    # Optimizer
    g_params = list(phoneCriterion.parameters())
    if not args.freeze:
        print("Optimizing model")
        g_params += list(featureMaker.parameters())

    optimizer = torch.optim.Adam(g_params,
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    pathArgs = os.path.join(args.output, "args_training.json")
    with open(pathArgs, 'w') as file:
        json.dump(vars(args), file, indent=2)

    run_training(trainLoader, valLoader, featureMaker, phoneCriterion,
                 optimizer, downsamplingFactor, args.nEpochs, pathCheckpoint)