Example #1
def main(argv):

    args = parse_args(argv)

    if args.load == 'from_checkpoint':
        # Checkpoint
        model = loadModel([args.path_checkpoint])[0]
        model.gAR.keepHidden = True
        # Feature maker
        feature_maker = FeatureModule(model, args.get_encoded).cuda().eval()

        def feature_function(x): return buildFeature(feature_maker, x,
                                                     seqNorm=args.seq_norm,
                                                     strict=args.strict,
                                                     maxSizeSeq=args.max_size_seq)
    elif args.load == 'from_pre_computed':
        def feature_function(x): return torch.load(x, 'cpu')
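        # Here x is expected to be the path of a pre-computed feature tensor
        # saved with torch.save (assumption based on the branch name).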

    # Modes
    if args.mode == 'all':
        modes = ["within", "across"]
    else:
        modes = [args.mode]

    distance_mode = 'cosine'

    step_feature = 1 / args.feature_size
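    # e.g. assuming args.feature_size is the frame period in seconds (0.01 s),
    # step_feature = 1 / 0.01 = 100 feature frames per second.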

    # Get the list of sequences
    seq_list, _ = findAllSeqs(args.path_dataset, extension=args.file_extension)
    seq_list = [(str(Path(x).stem), str(Path(args.path_dataset) / x))
                for (_, x) in seq_list]

    if args.debug:
        seq_list = seq_list[:1000]

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 seq_norm=args.seq_norm,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    out_dir = Path(args.path_checkpoint).parent if args.out is None \
        else Path(args.out)
    out_dir.mkdir(exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)
Example #2
def main(argv):

    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False, keep_hidden=True):
            """
            Load a CPC model (and its hidden dimensions) from a CPC checkpoint file.
            """
            # Set LSTM level
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None

            # Load CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint, updateConfig=updateConfig)

            # Keep hidden units at LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden

            # Build CPC Feature Maker from CPC model
            # featureMaker = fl.FeatureModule(model, get_encoded=get_encoded)
            # return featureMaker
            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(args.load,
                                                             loadStateDict=not args.no_pretraining)

        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load lstm_lm model from checkpoint.
            """
            # Set up the args Namespace
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint
                )
            
            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            model = models[0]
            return model

        model = loadCheckpoint(args.load[0], args.pathDB)
        dim_features = 768

    dim_inter = args.dim_inter
    # Now the criterion

    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        speakers_factorized.load_state_dict(torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            # Orthonormal basis of the right null space of At, computed via SVD
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcondt = torch.finfo(st.dtype).eps * max(Mt, Nt)
            else:
                rcondt = rcond
            tolt = torch.max(st) * rcondt
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            # nullspace.backward(torch.ones_like(nullspace), retain_graph=True)
            return nullspace

        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)
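        # Sanity-check sketch (not in the original script): the columns returned by
        # my_nullspace span the right null space of its input, so, assuming the first
        # layer of the speaker classifier maps dim_features -> dim_inter:
        #   W = speakers_factorized.linearSpeakerClassifier[0].weight  # (dim_inter, dim_features)
        #   N = my_nullspace(W)                                        # (dim_features, dim_features - rank(W))
        #   assert torch.allclose(W.cpu() @ N, torch.zeros_like(W.cpu() @ N), atol=1e-4)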

    phone_labels = None
    if args.pathPhone is not None:

        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'

        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print(f"Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers), nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile, map_location=torch.device('cpu'))['state_dict']['Ck']
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
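            # Ck is assumed to be stored with a leading batch dimension,
            # e.g. (1, nClusters, dim); the reshape above drops it.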
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key, centerpushSettings=centerpushSettings)
Example #3
def main(pathClusteringCheckpoint,
         pathActivations,
         pathOutputDir,
         cpu=False,
         debug=False,
         file_extension=".pt",
         recursionLevel=2,
         resume=False,
         seqList=None,
         split=None):
    # Test the extension is valid
    if file_extension not in ['.txt', '.npy', '.pt']:
        raise ValueError(
            f'Activation file extension invalid ({file_extension})')

    args = argparse.Namespace(**locals())
    print("=============================================================")
    print(f"Quantizing data from {pathActivations}")
    print("=============================================================")

    # Get splits
    if split:
        assert len(split.split("-")) == 2 \
           and int(split.split("-")[1]) >= int(split.split("-")[0]) >= 1, \
               "SPLIT must be under the form idxSplit-numSplits (numSplits >= idxSplit >= 1), eg. --split 1-20"
        idx_split, num_splits = split.split("-")
        idx_split = int(idx_split)
        num_splits = int(num_splits)

    # Find all sequences
    print("")
    print(f"Looking for all {file_extension} files in {pathActivations}")
    seqNames, _ = findAllSeqs(pathActivations,
                              speaker_level=recursionLevel,
                              extension=file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(
            seqNames[0][1])[1].endswith(file_extension):
        print(
            "Seems like the _seq_cache.txt does not contain the correct extension, reloading the file list"
        )
        seqNames, _ = findAllSeqs(pathActivations,
                                  speaker_level=recursionLevel,
                                  extension=file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Filter specific sequences
    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
        print(f"Done! {len(seqNames)} remaining files after filtering!")
    assert len(seqNames) > 0, \
        "No file to be quantized!"

    # Check if directory exists
    pathOutputDir = Path(pathOutputDir)
    if not pathOutputDir.exists():
        print("")
        print(f"Creating the output directory at {pathOutputDir}")
        pathOutputDir.mkdir(parents=True, exist_ok=True)
    writeArgs(pathOutputDir / "_info_args.json", args)

    # Check if output file exists
    if not split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    outputFile = pathOutputDir / nameOutput

    # Get splits
    if split:
        startIdx = len(seqNames) // num_splits * (idx_split - 1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(
                len(seqNames) // num_splits * idx_split, len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(
            f"Quantizing split {idx_split} out of {num_splits} splits, "
            f"with {len(seqNames)} files (idx in range({startIdx}, {endIdx}))."
        )
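        # Worked example (illustrative): with 1000 files and --split 3-10,
        # startIdx = 1000 // 10 * 2 = 200 and endIdx = 1000 // 10 * 3 = 300,
        # so this process handles files 200..299.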

    # Debug mode
    if debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Continue
    # Whether the next line written needs a leading '\n' separator (True once a
    # line has been written, or when resuming a file without a trailing newline)
    addEndLine = False
    if resume:
        if outputFile.exists():
            with open(outputFile, 'r') as f:
                lines = [line for line in f]
            existing_files = set([x.split()[0] for x in lines if x.split()])
            seqNames = [
                s for s in seqNames if os.path.splitext(s[1].split('/')[-1])[0]
                not in existing_files
            ]
            print(
                f"Found existing output file, continuing with the {len(seqNames)} remaining audio files!"
            )
            if len(lines) > 0 and not lines[-1].endswith("\n"):
                addEndLine = True
    else:
        print(outputFile, outputFile.exists())
        assert not outputFile.exists(), \
            f"Output file {outputFile} already exists !!! " \
            f"If you want to continue quantizing audio files, please check the --resume option."

    assert len(seqNames) > 0, \
        "No file to be quantized!"

    # Load Clustering args
    pathCheckpoint = Path(pathClusteringCheckpoint)
    assert pathCheckpoint.suffix == ".pt"
    if Path(str(pathCheckpoint.with_suffix('')) + '_args.json').exists():
        pathConfig = Path(str(pathCheckpoint.with_suffix('')) + '_args.json')
    elif (pathCheckpoint.parent / "checkpoint_args.json").exists():
        pathConfig = pathCheckpoint.parent / "checkpoint_args.json"
    else:
        assert False, \
            f"Args file not found in the directory {pathCheckpoint.parent}"
    clustering_args = readArgs(pathConfig)
    print("")
    print(
        f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}"
    )
    print('-' * 50)

    # Load ClusterModule
    print("")
    print(f"Loading ClusterModule at {pathCheckpoint}")
    clusterModule = loadClusterModule(pathCheckpoint)
    if not cpu:
        clusterModule.cuda()
    print("ClusterModule loaded!")

    # Quantization of files
    print("")
    print(f"Quantizing activation files and saving outputs to {outputFile}...")
    f = open(outputFile, "a")
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(pathActivations, file_path)

        # Quantizing
        quantLine = quantize_file(file_path, clusterModule, cpu=cpu)

        # Save the outputs
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLine = "\t".join([file_name, quantLine])
        if addEndLine:
            f.write("\n" + outLine)
        else:
            f.write(outLine)
            addEndLine = True
    bar.finish()
    print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
    f.close()
Example #4
def main(argv):
    # Args parser
    args = parseArgs(argv)
    
    print("=============================================================")
    print(f"Quantizing data from {args.pathDB}")
    print("=============================================================")

    # Check if directory exists
    if not os.path.exists(args.pathOutput):
        print("")
        print(f"Creating the output directory at {args.pathOutput}")
        Path(args.pathOutput).mkdir(parents=True, exist_ok=True)

    # Get splits
    if args.split:
        assert len(args.split.split("-"))==2 and int(args.split.split("-")[1]) >= int(args.split.split("-")[0]) >= 1, \
            "SPLIT must be under the form idxSplit-numSplits (numSplits >= idxSplit >= 1), eg. --split 1-20"
        idx_split, num_splits = args.split.split("-")
        idx_split = int(idx_split)
        num_splits = int(num_splits)

    # Find all sequences
    print("")
    print(f"Looking for all {args.file_extension} files in {args.pathDB} with speakerLevel {args.recursionLevel}")
    seqNames, speakers = findAllSeqs(args.pathDB,
                                 speaker_level=args.recursionLevel,
                                 extension=args.file_extension,
                                 loadCache=True)

    if args.pathSeq:
        with open(args.pathSeq, 'r') as f:
            seqs = set([x.strip() for x in f])

        filtered = []
        for s in seqNames:
            if s[1].split('/')[-1].split('.')[0] in seqs:
                filtered.append(s)
        seqNames = filtered

    print(f"Done! Found {len(seqNames)} files and {len(speakers)} speakers!")
    if args.separate_speaker:
        seqNames_by_speaker = {}
        for seq in seqNames:
            speaker = seq[1].split("/")[args.recursionLevel-1]
            if speaker not in seqNames_by_speaker:
                seqNames_by_speaker[speaker] = []
            seqNames_by_speaker[speaker].append(seq)

    # Check if output file exists
    if not args.split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    if args.separate_speaker is False:
        outputFile = os.path.join(args.pathOutput, nameOutput)
        assert not os.path.exists(outputFile), \
            f"Output file {outputFile} already exists !!!"
    
    # Get splits
    if args.split:
        startIdx = len(seqNames) // num_splits * (idx_split-1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(len(seqNames) // num_splits * idx_split, len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(f"Quantizing split {idx_split} out of {num_splits} splits, with {len(seqNames)} files (idx in range({startIdx}, {endIdx})).")

    # Debug mode
    if args.debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Load Clustering args
    assert args.pathCheckpoint[-3:] == ".pt"
    if os.path.exists(args.pathCheckpoint[:-3] + "_args.json"):
        pathConfig = args.pathCheckpoint[:-3] + "_args.json"
    elif os.path.exists(os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json")):
        pathConfig = os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json")
    else:
        assert False, \
            f"Args file not found in the directory {os.path.dirname(args.pathCheckpoint)}"
    clustering_args = readArgs(pathConfig)
    print("")
    print(f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}")
    print('-' * 50)

    # Load ClusterModule
    clusterModule = loadClusterModule(args.pathCheckpoint, norm_vec_len=args.norm_vec_len)
    clusterModule.cuda()

    # Load FeatureMaker
    print("")
    print("Loading CPC FeatureMaker")
    if 'level_gru' in vars(clustering_args) and clustering_args.level_gru is not None:
        updateConfig = argparse.Namespace(nLevelsGRU=clustering_args.level_gru)
    else:
        updateConfig = None
    model = loadModel([clustering_args.pathCheckpoint], updateConfig=updateConfig)[0]
    # If we do not use the batch implementation, the LSTM can keep its hidden
    # state across sequential chunks, which improves the quality of the
    # quantized units
    if args.nobatch:
        model.gAR.keepHidden = True
    featureMaker = FeatureModule(model, clustering_args.encoder_layer)
    if clustering_args.dimReduction is not None:
        dimRed = loadDimReduction(clustering_args.dimReduction, clustering_args.centroidLimits)
        featureMaker = torch.nn.Sequential(featureMaker, dimRed)
    if not clustering_args.train_mode:
        featureMaker.eval()
    featureMaker.cuda()
    def feature_function(x):
        if args.nobatch is False:
            res0 = buildFeature_batch(featureMaker, x,
                                      seqNorm=False,
                                      strict=args.strict,
                                      maxSizeSeq=args.max_size_seq,
                                      batch_size=args.batch_size)
        else:
            res0 = buildFeature(featureMaker, x,
                                seqNorm=False,
                                strict=args.strict)
        if args.norm_vec_len:
            # [!] we actually used CPC_audio/scripts/quantize_audio.py for that in the end
            # Normalize each feature vector to unit L2 length
            res0Lengths = torch.sqrt((res0 * res0).sum(2))
            res0 = res0 / res0Lengths.view(*(res0Lengths.shape), 1)
        return res0
    print("CPC FeatureMaker loaded!")
    
    # Quantization of files
    print("")
    print(f"Quantizing audio files...")
    seqQuantLines = []
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(args.pathDB, file_path)

        # Get features & quantizing
        cFeatures = feature_function(file_path).cuda()

        nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1)

        cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))
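        # The feature dimension is assumed to be a multiple of the centroid
        # dimension clusterModule.Ck.size(-1); it is split into nGroups chunks,
        # each quantized independently, hence one row per (frame, group) pair.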

        if len(vals) > 2 and int(vals[-1]) > 9400000: # Librilight, to avoid OOM
            clusterModule = clusterModule.cpu()
            cFeatures = cFeatures.cpu()
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
            clusterModule = clusterModule.cuda()
        else:
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
        qFeatures = qFeatures[0].detach().cpu().numpy()

        # Transform to quantized line
        quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
        seqQuantLines.append(quantLine)

    bar.finish()
    print(f"...done {len(seqQuantLines)} files in {time()-start_time} seconds.")

    # Saving outputs
    print("")
    print(f"Saving outputs to {outputFile}")
    outLines = []
    for vals, quantln in zip(seqNames, seqQuantLines):
        file_path = vals[1]
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLines.append("\t".join([file_name, quantln]))
    with open(outputFile, "w") as f:
        f.write("\n".join(outLines))
Example #5
def main(argv):
    # Args parser
    args = parseArgs(argv)

    print("=============================================================")
    print(f"Quantizing data from {args.pathDB}")
    print("=============================================================")

    # Get splits
    if args.split:
        assert len(args.split.split("-"))==2 and int(args.split.split("-")[1]) >= int(args.split.split("-")[0]) >= 1, \
            "SPLIT must be under the form idxSplit-numSplits (numSplits >= idxSplit >= 1), eg. --split 1-20"
        idx_split, num_splits = args.split.split("-")
        idx_split = int(idx_split)
        num_splits = int(num_splits)

    # Find all sequences
    print("")
    print(f"Looking for all {args.file_extension} files in {args.pathDB}")
    seqNames, _ = findAllSeqs(args.pathDB,
                              speaker_level=1,
                              extension=args.file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][1])[1].endswith(
            args.file_extension):
        print(
            "Seems like the _seq_cache.txt does not contain the correct extension, reloading the file list"
        )
        seqNames, _ = findAllSeqs(args.pathDB,
                                  speaker_level=1,
                                  extension=args.file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Filter specific sequences
    if args.pathSeq:
        print("")
        print(f"Filtering seqs in {args.pathSeq}")
        with open(args.pathSeq, 'r') as f:
            seqs = set([x.strip() for x in f])
        filtered = []
        for s in seqNames:
            if os.path.splitext(s[1].split('/')[-1])[0] in seqs:
                filtered.append(s)
        seqNames = filtered
        print(f"Done! {len(seqNames)} files filtered!")

    # Check if directory exists
    if not os.path.exists(args.pathOutputDir):
        print("")
        print(f"Creating the output directory at {args.pathOutputDir}")
        Path(args.pathOutputDir).mkdir(parents=True, exist_ok=True)
    writeArgs(os.path.join(args.pathOutputDir, "_info_args.json"), args)

    # Check if output file exists
    if not args.split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    outputFile = os.path.join(args.pathOutputDir, nameOutput)

    # Get splits
    if args.split:
        startIdx = len(seqNames) // num_splits * (idx_split - 1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(
                len(seqNames) // num_splits * idx_split, len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(
            f"Quantizing split {idx_split} out of {num_splits} splits, with {len(seqNames)} files (idx in range({startIdx}, {endIdx}))."
        )

    # Debug mode
    if args.debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Continue
    # Whether the next line written needs a leading '\n' separator (True once a
    # line has been written, or when resuming a file without a trailing newline)
    addEndLine = False
    if args.resume:
        if os.path.exists(outputFile):
            with open(outputFile, 'r') as f:
                lines = [line for line in f]
            existing_files = set([x.split()[0] for x in lines if x.split()])
            seqNames = [
                s for s in seqNames if os.path.splitext(s[1].split('/')[-1])[0]
                not in existing_files
            ]
            print(
                f"Found existing output file, continuing with the {len(seqNames)} remaining audio files!"
            )
            if len(lines) > 0 and not lines[-1].endswith("\n"):
                addEndLine = True
    else:
        assert not os.path.exists(outputFile), \
            f"Output file {outputFile} already exists !!! If you want to continue quantizing audio files, please check the --resume option."

    assert len(seqNames) > 0, \
        "No file to be quantized!"

    # Load Clustering args
    assert args.pathClusteringCheckpoint[-3:] == ".pt"
    if os.path.exists(args.pathClusteringCheckpoint[:-3] + "_args.json"):
        pathConfig = args.pathClusteringCheckpoint[:-3] + "_args.json"
    elif os.path.exists(
            os.path.join(os.path.dirname(args.pathClusteringCheckpoint),
                         "checkpoint_args.json")):
        pathConfig = os.path.join(
            os.path.dirname(args.pathClusteringCheckpoint),
            "checkpoint_args.json")
    else:
        assert False, \
            f"Args file not found in the directory {os.path.dirname(args.pathClusteringCheckpoint)}"
    clustering_args = readArgs(pathConfig)
    print("")
    print(
        f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}"
    )
    print('-' * 50)

    # Load ClusterModule
    print("")
    print(f"Loading ClusterModule at {args.pathClusteringCheckpoint}")
    clusterModule = loadClusterModule(args.pathClusteringCheckpoint)
    if not args.cpu:
        clusterModule.cuda()
    print("ClusterModule loaded!")

    # Get the CPC checkpoint path from clustering args
    if not os.path.isabs(
            clustering_args.pathCheckpoint):  # Maybe it's relative path
        clustering_args.pathCheckpoint = os.path.join(
            os.path.dirname(os.path.abspath(args.pathClusteringCheckpoint)),
            clustering_args.pathCheckpoint)
    assert os.path.exists(clustering_args.pathCheckpoint), \
        f"CPC path at {clustering_args.pathCheckpoint} does not exist!!"

    # Load FeatureMaker
    print("")
    print(f"Loading CPC FeatureMaker from {clustering_args.pathCheckpoint}")
    # If we do not use the batch implementation, the LSTM can keep its hidden
    # state across sequential chunks, which improves the quality of the
    # quantized units (hence keep_hidden=args.nobatch)
    featureMaker = loadCPCFeatureMaker(
        clustering_args.pathCheckpoint,
        gru_level=vars(clustering_args).get('level_gru', None),
        get_encoded=clustering_args.encoder_layer,
        keep_hidden=args.nobatch)
    if clustering_args.dimReduction is not None:
        dimRed = loadDimReduction(clustering_args.dimReduction,
                                  clustering_args.centroidLimits)
        featureMaker = torch.nn.Sequential(featureMaker, dimRed)
    if not clustering_args.train_mode:
        featureMaker.eval()
    if not args.cpu:
        featureMaker.cuda()

    def cpc_feature_function(x):
        if args.nobatch is False:
            return buildFeature_batch(featureMaker,
                                      x,
                                      seqNorm=False,
                                      strict=args.strict,
                                      maxSizeSeq=args.max_size_seq,
                                      batch_size=args.batch_size)
        else:
            return buildFeature(featureMaker,
                                x,
                                seqNorm=False,
                                strict=args.strict)

    print("CPC FeatureMaker loaded!")

    # Quantization of files
    print("")
    print(f"Quantizing audio files and saving outputs to {outputFile}...")
    f = open(outputFile, "a")
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(args.pathDB, file_path)

        # Quantizing
        quantLine = quantize_file(file_path, cpc_feature_function,
                                  clusterModule)

        # Save the outputs
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLine = "\t".join([file_name, quantLine])
        if addEndLine:
            f.write("\n" + outLine)
        else:
            f.write(outLine)
            addEndLine = True
    bar.finish()
    print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
    f.close()
Example #6
def main(argv):

    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    dim_features = hidden_encoder if args.get_encoded else hidden_gar

    # Now the criterion
    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones,
                                          args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones,
                                             args.get_encoded)
    else:
        print(f"Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size,
                                          "uniform",
                                          True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size,
                                      'sequential',
                                      False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params,
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
Example #7
    args = parseArgs(sys.argv[1:])
    # Export absolute paths for later use
    args.pathCheckpoint = os.path.abspath(args.pathCheckpoint)
    args.pathOutput = os.path.abspath(args.pathOutput)
    args.pathDB = os.path.abspath(args.pathDB)

    if not args.load:
        assert os.path.exists(args.pathOutput) is False, \
            f"The output file {args.pathOutput} already exists, please check the option --load !"
        assert os.path.exists(os.path.join(os.path.dirname(args.pathOutput), "checkpoint_last.pt")) is False, \
            f"Found checkpoint_last.pt in the output directory, please check the option --load !"

    print(args)
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     speaker_level=args.recursionLevel,
                                     extension=args.extension,
                                     loadCache=True)

    if args.seqList is not None:
        seqNames = filterSeqs(args.seqList, seqNames)
    if args.debug:
        nsamples = 1000
        print(f"Debug mode activated, get only {nsamples} samples!")
        shuffle(seqNames)
        seqNames = seqNames[:nsamples]
    if args.getDistanceEstimation:
        shuffle(seqNames)
        seqNames = seqNames[:5000]

    print("")
    print(f'Loading audio data at {args.pathDB}')
Example #8
    parser.add_argument('--getEncoded', action='store_true')
    parser.add_argument('--clusters', type=str, default=None)
    parser.add_argument('--seqNorm', action='store_true')

    args = parser.parse_args()

    if not os.path.isdir(args.pathOut):
        os.mkdir(args.pathOut)

    with open(os.path.join(os.path.dirname(args.pathOut),
                           f"{os.path.basename(args.pathOut)}.json"), 'w') \
            as file:
        json.dump(vars(args), file, indent=2)

    outData = [
        x[1] for x in findAllSeqs(
            args.pathDB, extension=args.extension, loadCache=False)[0]
    ]

    featureMaker = loadModel([args.pathCheckpoint])[0]
    stepSize = featureMaker.gEncoder.DOWNSAMPLING / 16000
    print(f"stepSize : {stepSize}")
    featureMaker = FeatureModule(featureMaker, args.getEncoded)
    featureMaker.collapse = False

    if args.addCriterion:
        criterion, nPhones = loadSupervisedCriterion(args.pathCheckpoint)
        featureMaker = ModelPhoneCombined(featureMaker, criterion, nPhones,
                                          args.oneHot)
    if device == "cuda":
        featureMaker = featureMaker.cuda(device=0)
Example #9
    if args.command == 'per':
        args = get_PER_args(args)

    # Output Directory
    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

    phoneLabels, nPhones = parseSeqLabels(args.pathPhone)

    inSeqs, _ = findAllSeqs(args.pathDB, extension=args.file_extension)
    # Datasets
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None and args.command == 'train':
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    elif args.pathVal is not None:
        seqVal = filterSeqs(args.pathVal, inSeqs)
    else:
        raise RuntimeError("No validation dataset found for PER computation")
Example #10
def main(args):

    # import ptvsd
    # ptvsd.enable_attach(('0.0.0.0', 7309))
    # print("Attach debugger now")
    # ptvsd.wait_for_attach()

    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    os.makedirs(args.pathCheckpoint, exist_ok=True)
    if not args.onlyCapture and not args.only_classif_metric:
        json.dump(vars(args), open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'), 'wt'))
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr, 
                                    args.captureCtxRepr, 
                                    args.captureSpeakerAlign, 
                                    args.capturePhoneAlign,
                                    args.capturePred,
                                    args.captureCPCCTCalign,
                                    args.captureCPCCTClogScores], 
                                    ['conv_repr', 'ctx_repr', 'speaker_align', 'phone_align', 'pred', 'cpcctc_align', 'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        ###assert len(whatToSave) > 0
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames, 
                                percentage=args.captureDSfreq, totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqTrain,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader,
                                    MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:

        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None
            
        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqCapture,
                                    phoneLabelsForCapture,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None


        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        #     they capture "best in this part of training" as "best" (apart from capturing current state)
        #     so if best is in epoch 100 and training is paused and resumed from checkpoint
        #     in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        #     (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder            

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING if isinstance(cpcModel, model.CPCModelNullspace) else cpcModel.gEncoder.DOWNSAMPLING
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],  downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()
    
    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                           lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                                                               n_epoch, epoch),
                                                           last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for i in range(len(logs["epoch"])):
            scheduler.step()
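        # Replaying one scheduler.step() per already-completed epoch restores
        # the learning-rate schedule state when resuming from a checkpoint.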

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()
    
    if args.supervised_classif_metric:

        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)
            
            if not args.CTCphones:
                print(f"Running phone separability with aligned phones")
            else:
                print(f"Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                if not args.CTCphones:
                    # print(f"Running phone separability with aligned phones")
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                nPhonesInData, args.phone_get_encoded,
                                                nLayers=args.linsep_net_layers)
                else:
                    # print(f"Running phone separability with CTC loss")
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                    nPhonesInData, args.phone_get_encoded,
                                                    nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(phone_g_params, lr=args.linsep_lr,
                                            betas=(args.linsep_beta1, args.linsep_beta2),
                                            eps=args.linsep_epsilon)
                
                return phone_criterion, phone_optimizer
        
        if args.speaker_sep:
            print(f"Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features, len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion, device_ids=range(args.nGPU))

                speaker_g_params = list(speaker_criterion.parameters())

                speaker_optimizer = torch.optim.Adam(speaker_g_params, lr=args.linsep_lr,
                                            betas=(args.linsep_beta1, args.linsep_beta2),
                                            eps=args.linsep_epsilon)

                return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                    phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
                                        numWorkers=0)

        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
                                    numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            if not os.path.exists(log_path_for_epoch):
                os.makedirs(log_path_for_epoch)
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            if not os.path.exists(log_path_phoneme):
                os.makedirs(log_path_phoneme)
            if not os.path.exists(log_path_speaker):
                os.makedirs(log_path_speaker)
            # Default to None so these paths exist even when no checkpoint
            # directory is given (assumes the training helper accepts None)
            checkpoint_path_phoneme = None
            checkpoint_path_speaker = None
            if args.linsep_checkpoint_dir:
                checpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checpoint_path_for_epoch, "speaker/")
                if not os.path.exists(checkpoint_path_phoneme):
                    os.makedirs(checkpoint_path_phoneme)
                if not os.path.exists(checkpoint_path_speaker):
                    os.makedirs(checkpoint_path_speaker)
            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                            runLinsepClassificationTraining)

    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        trainedEpoch = len(logs["epoch"]) - 1
        # runLinsepClassificationTraining is defined above when args.supervised_classif_metric is set
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
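The (interval, callback) pair built above is what the outer training loop consumes to schedule the linear-separability probes. The run() implementation is not shown in this example, so the following is only a minimal sketch of that pattern; the names sketch_training_loop, interval and classif_callback are illustrative assumptions, not names from the original code.

def sketch_training_loop(n_epoch, cpc_model, linsep_task_config, logs):
    interval, classif_callback = linsep_task_config
    for epoch in range(n_epoch):
        # ... one CPC training epoch would run here ...
        if interval is not None and (epoch + 1) % interval == 0:
            # probe the current representations with the linear classifiers
            logs.setdefault("linsep", []).append(
                classif_callback(epoch, cpc_model, epoch))
    return logs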
Exemple #11
0
def main(argv):
    # Args parser
    args = parseArgs(argv)

    print("=============================================================")
    print(f"Building CPC features from {args.pathDB}")
    print("=============================================================")

    # Find all sequences
    print("")
    print(f"Looking for all {args.file_extension} files in {args.pathDB}")
    seqNames, _ = findAllSeqs(args.pathDB,
                              speaker_level=1,
                              extension=args.file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][-1])[1].endswith(
            args.file_extension):
        print(
            "Seems like the _seq_cache.txt does not contain the correct extension, reloading the file list"
        )
        seqNames, _ = findAllSeqs(args.pathDB,
                                  speaker_level=1,
                                  extension=args.file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Verify the output directory
    if os.path.exists(args.pathOutputDir):
        existing_files = set([
            os.path.splitext(os.path.basename(x))[0]
            for x in os.listdir(args.pathOutputDir) if x[-4:] == ".npy"
        ])
        seqNames = [
            s for s in seqNames if os.path.splitext(os.path.basename(s[1]))[0]
            not in existing_files
        ]
        print(
            f"Found existing output directory at {args.pathOutputDir}, continuing with the {len(seqNames)} remaining audio files!"
        )
    else:
        print("")
        print(f"Creating the output directory at {args.pathOutputDir}")
        Path(args.pathOutputDir).mkdir(parents=True, exist_ok=True)
    writeArgs(os.path.join(args.pathOutputDir, "_info_args.json"), args)

    # Debug mode
    if args.debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Load CPC feature maker
    print("")
    print(f"Loading CPC featureMaker from {args.pathCPCCheckpoint}")
    featureMaker = loadCPCFeatureMaker(args.pathCPCCheckpoint,
                                       gru_level=args.gru_level,
                                       get_encoded=args.get_encoded,
                                       keep_hidden=True)
    featureMaker.eval()
    if not args.cpu:
        featureMaker.cuda()
    print("CPC FeatureMaker loaded!")

    # Define CPC_feature_function
    def CPC_feature_function(x):
        CPC_features = buildFeature(featureMaker,
                                    x,
                                    seqNorm=args.seq_norm,
                                    strict=args.strict,
                                    maxSizeSeq=args.max_size_seq)
        return CPC_features.squeeze(0).float().cpu().numpy()

    # Building features
    print("")
    print(
        f"Building CPC features and saving outputs to {args.pathOutputDir}...")
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()

    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(args.pathDB, file_path)

        # Computing features
        CPC_features = CPC_feature_function(file_path)

        # Save the outputs as compressed Kaldi archives in the output directory
        file_name = os.path.splitext(
            os.path.basename(file_path))[0] + ".ark.gz"
        file_out = os.path.join(args.pathOutputDir, file_name)
        with WriteHelper(f"ark:| gzip -c > {file_out}") as writer:
            writer('arr_0', CPC_features)
    bar.finish()
    print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
Exemple #12
0
def main(pathCheckpoint, pathDB, pathOutputDir, batch_size=8, debug=False,
         file_extension='.wav', layer='all', max_size_seq=64000,
         output_file_extension='.txt', recursionLevel=2, seqList=None,
         audio_features_fn='mfcc_features.pt',
         image_features_fn='resnet_features.pt', cpc_model_path=None,
         cpc_gru_level=-1, zr_format=False):

    args = argparse.Namespace(**locals())
    print("=============================================================")
    print(f"Extract activations from VG model for {pathDB}")
    print("=============================================================")

    # Initializing feature extraction config
    # /!\ Code duplication with preprocessing.py
    # Should probably store the feature config on disk.
    _audio_feat_config = dict(type='mfcc', delta=True, alpha=0.97, n_filters=40,
                              window_size=0.025, frame_shift=0.010, audio_features_fn=audio_features_fn)
    _images_feat_config = dict(model='resnet', image_features_fn=image_features_fn)

    if cpc_model_path is not None:
        if audio_features_fn == 'mfcc_features.pt':
            audio_features_fn = 'cpc_features.pt'
        _audio_feat_config = dict(type='cpc', model_path=cpc_model_path, audio_features_fn=audio_features_fn,
                                  strict=False, seq_norm=False, max_size_seq=10240, gru_level=cpc_gru_level,
                                  on_gpu=True)

    # Find all sequences
    print("")
    print(f"Looking for all {file_extension} files in {pathDB}")
    seqNames, _ = findAllSeqs(pathDB,
                              speaker_level=recursionLevel,
                              extension=file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][-1])[1].endswith(file_extension):
        print("Seems like the _seq_cache.txt does not contain the correct extension, reload the file list")
        seqNames, _ = findAllSeqs(pathDB,
                                  speaker_level=recursionLevel,
                                  extension=file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Filter specific sequences
    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
        print(f"Done! {len(seqNames)} remaining files after filtering!")
    assert len(seqNames) > 0, \
        "No file to be processed!"

    pathOutputDir = Path(pathOutputDir)
    print("")
    print(f"Creating the output directory at {pathOutputDir}")
    pathOutputDir.mkdir(parents=True, exist_ok=True)
    writeArgs(pathOutputDir / "_info_args.json", args)

    # Debug mode
    if debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Loading audio features
    print("")
    print(f"Loading audio features for {pathDB}")
    pathDB = Path(pathDB)
    if seqList is None:
        cache_fpath = pathDB / args.audio_features_fn
        if cache_fpath.exists():
            print(f"Found cached features ({cache_fpath}). Loading them.")
            features = torch.load(cache_fpath)
        else:
            print('No cached features. Computing them from scratch.')
            audio_fpaths = [pathDB / s[1] for s in seqNames]
            features = compute_audio_features(audio_fpaths, max_size_seq, _audio_feat_config)
            print(f'Caching features ({cache_fpath}).')
            torch.save(features, cache_fpath)
    else:
        print('Computing features.')
        audio_fpaths = [pathDB / s[1] for s in seqNames]
        features = compute_audio_features(audio_fpaths, max_size_seq, _audio_feat_config)


    # Load VG model
    print("")
    print(f"Loading VG model from {pathCheckpoint}")
    vg_model = torch.load(pathCheckpoint)
    print("VG model loaded!")

    # Extracting activations
    print("")
    print(f"Extracting activations and saving outputs to {pathOutputDir}...")
    data = torch.utils.data.DataLoader(dataset=features,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=lambda x: dataset.batch_audio(x, max_frames=None))

    i_next = 0
    zr_keywords = ['phonetic', 'lexical', 'syntactic', 'semantic']
    if zr_format:
        splitted_path = str(pathDB).split('/')
        keyword_idx = None
        for keyword in zr_keywords:
            if keyword in splitted_path:
                keyword_idx = splitted_path.index(keyword)
                break
        if keyword_idx is None:
            raise ValueError(
                f"pathDB must contain one of {zr_keywords} when zr_format is set")
        suffix = '/'.join(splitted_path[keyword_idx:])
    else:
        suffix = ""

    for au, l in tqdm(data):
        activations = vg_model.SpeechEncoder.introspect(au.cuda(), l.cuda())
        fnames = [s[1] for s in seqNames[i_next: i_next + batch_size]]
        if layer == 'all':
            for k in activations:
                save_activations(activations[k], pathOutputDir / k / suffix, fnames,
                                 output_file_extension)
        elif layer in activations:
            save_activations(activations[layer],
                             pathOutputDir / layer / suffix, fnames,
                             output_file_extension)
        i_next += batch_size
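For reference, the zr_format branch above keeps the part of pathDB starting at the ZeroSpeech task keyword so the output tree mirrors the dataset layout. A self-contained pathlib equivalent might look like the sketch below; zr_suffix is a hypothetical helper, not part of the original script.

from pathlib import PurePath

ZR_KEYWORDS = ('phonetic', 'lexical', 'syntactic', 'semantic')

def zr_suffix(path_db):
    parts = PurePath(path_db).parts
    for keyword in ZR_KEYWORDS:
        if keyword in parts:
            # keep everything from the task keyword downwards, e.g. "phonetic/dev-clean"
            return '/'.join(parts[parts.index(keyword):])
    raise ValueError(f"{path_db} does not contain any of {ZR_KEYWORDS}")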
def parse_args(argv):
    parser = argparse.ArgumentParser(description='Linear separability trainer')
    parser.add_argument('--pathDB', type=str, default='../../../CPC_librispeech/dataset/LibriSpeech/train-clean-100/',
                        help="Path to the directory containing the audio data.")
    parser.add_argument('--pathTrain', type=str, default=None,
                        help="Path to the list of the training sequences.")
    parser.add_argument('--pathVal', type=str, default='../../../CPC_librispeech/dataset/LibriSpeech/train-clean-100/_seqs_cache.txt',
                        help="Path to the list of the test sequences.")
    parser.add_argument('--load', type=str, nargs='*', default=['../exp_100_gru_linear/checkpoint_95.pt'],
                        help="Path to the checkpoint to evaluate.")
    parser.add_argument('--pathPhone', type=str, default=None,
                        help="Path to the phone labels. If given, will"
                        " compute the phone separability.")
    parser.add_argument('--CTC', action='store_true',
                        help="Use the CTC loss (for phone separability only)")
    parser.add_argument('--pathCheckpoint', type=str, default='exp_linear_separability_out',
                        help="Path of the output directory where the "
                        " checkpoints should be dumped.")
    parser.add_argument('--nGPU', type=int, default=-1,
                        help='Number of GPUs. Default=-1, use all available '
                        'GPUs')
    parser.add_argument('--batchSizeGPU', type=int, default=8,
                        help='Batch size per GPU.')
    parser.add_argument('--n_epoch', type=int, default=10)
    parser.add_argument('--debug', action='store_true',
                        help='If activated, will load only a small number '
                        'of audio data.')
    parser.add_argument('--unfrozen', action='store_true',
                        help="If activated, update the feature network as well"
                        " as the linear classifier")
    parser.add_argument('--no_pretraining', action='store_true',
                        help="If activated, work from an untrained model.")
    parser.add_argument('--file_extension', type=str, default=".flac",
                        help="Extension of the audio files in pathDB.")
    parser.add_argument('--save_step', type=int, default=-1,
                        help="Frequency at which a checkpoint should be saved,"
                        " et to -1 (default) to save only the best checkpoint.")
    parser.add_argument('--get_encoded', action='store_true',
                        help="If activated, will work with the output of the "
                        " convolutional encoder (see CPC's architecture).")
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate.')
    parser.add_argument('--beta1', type=float, default=0.9,
                        help='Value of beta1 for the Adam optimizer.')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='Value of beta2 for the Adam optimizer.')
    parser.add_argument('--epsilon', type=float, default=2e-8,
                        help='Value of epsilon for the Adam optimizer.')
    parser.add_argument('--ignore_cache', action='store_true',
                        help="Activate if the sequences in pathDB have"
                        " changed.")
    parser.add_argument('--size_window', type=int, default=20480,
                        help="Number of frames to consider in each batch.")
    args = parser.parse_args(argv)
    if args.nGPU < 0:
        args.nGPU = torch.cuda.device_count()
    if args.save_step <= 0:
        args.save_step = args.n_epoch

    args.load = [str(Path(x).resolve()) for x in args.load]
    args.pathCheckpoint = str(Path(args.pathCheckpoint).resolve())

    return args
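A minimal usage sketch for parse_args; the paths are placeholders, not real data.

example_args = parse_args([
    '--pathDB', '/data/LibriSpeech/train-clean-100/',
    '--CTC',
    '--n_epoch', '10',
])
# save_step was left at -1, so it is reset to n_epoch and only the best
# checkpoint is kept; nGPU defaults to torch.cuda.device_count() when -1.
print(example_args.save_step, example_args.nGPU)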
Exemple #14
0
def main(args):
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args,
                        locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint", "debug", "restart",
                            "world_size", "n_nodes", "node_id",
                            "n_gpu_per_node", "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],
                                     cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params,
                                 lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        for i in range(len(logs["epoch"])):
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset, valDataset, batchSize, args.samplingType, cpcModel,
        cpcCriterion, args.nEpoch, args.pathCheckpoint, optimizer, scheduler,
        logs)
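The ramp scheduler above wraps utils.ramp_scheduling_function in a LambdaLR, but that function's body is not shown in this example. A plausible sketch, assuming a simple linear warm-up over the first schedulerRamp epochs, would be:

def ramp_scheduling_function(n_epoch_ramp, epoch):
    # linear learning-rate warm-up: the multiplier grows from 1/n_epoch_ramp to 1.0,
    # then stays at 1.0 once the ramp is over (an assumption, not the original code)
    if epoch >= n_epoch_ramp:
        return 1.0
    return (epoch + 1) / n_epoch_ramp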
def main(pathActivations, pathOutput, nGroups=1, nClusters=50, MAX_ITER=100,
         batchSizeGPU=50, debug=False, extension='.pt', getDistanceEstimation=False,
         load=False, perIterSize=-1, recursionLevel=2, save=False, save_last=5,
         seqList=None):
    # Test the extension is valid
    if extension not in ['.txt', '.npy', '.pt']:
        raise ValueError(f'Activation file extension invalid ({extension})')

    torch.cuda.empty_cache()

    args = argparse.Namespace(**locals())
    # Export absolute paths for later use
    pathActivations = os.path.abspath(pathActivations)
    pathOutput = os.path.abspath(pathOutput)

    if not load:
        assert os.path.exists(pathOutput) is False, \
            f"The output file {pathOutput} already exists, please check the option --load !"
        assert os.path.exists(os.path.join(os.path.dirname(pathOutput), "checkpoint_last.pt")) is False, \
            "Found last_checkpoint.pt in the output directory, please check the option --load !"

    print(args)
    seqNames, speakers = findAllSeqs(pathActivations,
                                     speaker_level=recursionLevel,
                                     extension=extension,
                                     loadCache=True)

    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
    if debug:
        nsamples = 1000
        print(f"Debug mode activated, get only {nsamples} samples!")
        shuffle(seqNames)
        seqNames = seqNames[:nsamples]
    if getDistanceEstimation:
        shuffle(seqNames)
        seqNames = seqNames[:5000]

    print("")
    print(f'Loading activations at {pathActivations}')
    start_time = time.time()
    dataset = SequentialData(pathActivations, seqNames, None)
    print(f"Dataset loaded in {time.time()-start_time} seconds !")
    print("")

    nGPUs = torch.cuda.device_count()
    if nGPUs == 0:
        raise RuntimeError('No GPU found')
    batchSize = batchSizeGPU * nGPUs
    dataloader = dataset.getDataLoader(batchSize, numWorkers=0)
    print(f"Length of dataLoader: {len(dataloader)}")
    print("")

    # Check if dir exists
    if not os.path.exists(os.path.dirname(pathOutput)) and os.path.dirname(pathOutput):
        Path(os.path.dirname(pathOutput)).mkdir(parents=True, exist_ok=True)

    pathConfig = f"{os.path.splitext(pathOutput)[0]}_args.json"
    with open(pathConfig, 'w') as file:
        json.dump(vars(args), file, indent=2)

    out_state_dict = {}
    print("Starting the clustering...")
    start_time = time.time()
    # Using a dumb lambda function to skip feature extraction as we start from
    # the activations
    clusters = kMeanGPU(dataloader, lambda x: x, nClusters, nGroups,
                        perIterSize=perIterSize,
                        MAX_ITER=MAX_ITER,
                        save=save, load=load,
                        save_dir=os.path.dirname(pathOutput),
                        save_last=save_last,
                        ).cpu()

    print(f'Ran clustering '
          f'in {time.time() - start_time:.2f} seconds')

    clusterModule = kMeanCluster(clusters)
    out_state_dict["state_dict"] = clusterModule.state_dict()
    out_state_dict["n_clusters"] = nClusters
    out_state_dict['dim'] = clusters.size(2)
    torch.save(out_state_dict, pathOutput)
    with open(pathConfig, 'w') as file:
        json.dump(vars(args), file, indent=2)
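The centroids saved above can later be used to quantize new activation frames. The sketch below is a hypothetical helper built on assumptions about kMeanCluster (centroids of shape (1, n_clusters, dim) stored in its state_dict, forward pass returning one distance per frame and cluster); it is not part of the original script.

def quantize_with_saved_clusters(path_clusters, features):
    """features: float tensor of shape (1, n_frames, dim) -> cluster index per frame."""
    state = torch.load(path_clusters, map_location='cpu')
    cluster_module = kMeanCluster(torch.zeros(1, state["n_clusters"], state["dim"]))
    cluster_module.load_state_dict(state["state_dict"])
    with torch.no_grad():
        distances = cluster_module(features)   # assumed shape: (1, n_frames, n_clusters)
    return distances.argmin(dim=-1)            # LongTensor of shape (1, n_frames)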
Exemple #16
0
import numpy as np
import torch

import cpc.feature_loader as fl
from cpc.eval.sungkyun_classifier import MLP, MobileNetV2, SpeakerClf
#from cpc.eval.sungkyun_libri_sel_dataloader import LibriSelectionDataset
from cpc.dataset import AudioBatchData, findAllSeqs, filterSeqs, parseSeqLabels

# Required to prevent multiprocessing errors when sharing tensors across workers.
torch.multiprocessing.set_sharing_strategy('file_system')

"""Config."""
SEL_FEAT = 'c'  # 'c' or 'z'
CPC_CHECKPOINT_PATH = '../exp_100_gru_linear/checkpoint_95.pt'  # or '../exp_100_lstm_transformer_unsup/'
MAX_EPOCHS = 2000

nGPU = torch.cuda.device_count()
"""Data loading..."""
seqNames, speakers = findAllSeqs(
    '../../../CPC_librispeech/dataset/LibriSpeech/train-clean-100/',
    extension=".flac",
    loadCache=False)
seqTrain = []
seqTest = []
cnt_spk = np.zeros(len(speakers))
for seqn in seqNames:
    spk_id, fname = seqn
    cnt_spk[spk_id] += 1
    if cnt_spk[spk_id] > 60:
        seqTest.append(seqn)
    else:
        seqTrain.append(seqn)

# seq_train = filterSeqs(args.pathTrain, seqNames)
# seq_val = filterSeqs(args.pathVal, seqNames)
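The per-speaker split above keeps the first 60 utterances of each speaker for training and sends the rest to the test set. A hedged refactoring sketch of the same logic as a reusable function (split_per_speaker is not part of the original code):

def split_per_speaker(seq_names, n_speakers, max_train_per_speaker=60):
    seq_train, seq_test = [], []
    counts = np.zeros(n_speakers, dtype=int)
    for spk_id, fname in seq_names:
        counts[spk_id] += 1
        if counts[spk_id] > max_train_per_speaker:
            seq_test.append((spk_id, fname))
        else:
            seq_train.append((spk_id, fname))
    return seq_train, seq_test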