def main(pathClusteringCheckpoint, pathActivations, pathOutputDir,
         cpu=False, debug=False, file_extension=".pt", recursionLevel=2,
         resume=False, seqList=None, split=None):
    # Test that the extension is valid
    if file_extension not in ['.txt', '.npy', '.pt']:
        raise ValueError(
            f'Activation file extension invalid ({file_extension})')

    args = argparse.Namespace(**locals())
    print("=============================================================")
    print(f"Quantizing data from {pathActivations}")
    print("=============================================================")

    # Get splits
    if split:
        assert len(split.split("-")) == 2 \
            and int(split.split("-")[1]) >= int(split.split("-")[0]) >= 1, \
            "SPLIT must be under the form idxSplit-numSplits (numSplits >= idxSplit >= 1), eg. --split 1-20"
        idx_split, num_splits = split.split("-")
        idx_split = int(idx_split)
        num_splits = int(num_splits)

    # Find all sequences
    print("")
    print(f"Looking for all {file_extension} files in {pathActivations}")
    seqNames, _ = findAllSeqs(pathActivations,
                              speaker_level=recursionLevel,
                              extension=file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][1])[1].endswith(file_extension):
        print("Seems like the _seq_cache.txt does not contain the correct extension; reloading the file list")
        seqNames, _ = findAllSeqs(pathActivations,
                                  speaker_level=recursionLevel,
                                  extension=file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Filter specific sequences
    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
        print(f"Done! {len(seqNames)} files remaining after filtering!")
    assert len(seqNames) > 0, \
        "No file to be quantized!"

    # Create the output directory if needed
    pathOutputDir = Path(pathOutputDir)
    if not pathOutputDir.exists():
        print("")
        print(f"Creating the output directory at {pathOutputDir}")
        pathOutputDir.mkdir(parents=True, exist_ok=True)
    writeArgs(pathOutputDir / "_info_args.json", args)

    # Check if the output file exists
    if not split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    outputFile = pathOutputDir / nameOutput

    # Restrict the file list to the requested split
    if split:
        startIdx = len(seqNames) // num_splits * (idx_split - 1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(len(seqNames) // num_splits * idx_split, len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(
            f"Quantizing split {idx_split} out of {num_splits} splits, "
            f"with {len(seqNames)} files (idx in range({startIdx}, {endIdx})).")

    # Debug mode
    if debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only loading {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Resume from an existing output file
    addEndLine = False  # whether to prepend a newline before the next written line
    if resume:
        if outputFile.exists():
            with open(outputFile, 'r') as f:
                lines = [line for line in f]
            existing_files = set([x.split()[0] for x in lines if x.split()])
            seqNames = [s for s in seqNames
                        if os.path.splitext(s[1].split('/')[-1])[0] not in existing_files]
            print(f"Found an existing output file; {len(seqNames)} audio files left to quantize!")
            if len(lines) > 0 and not lines[-1].endswith("\n"):
                addEndLine = True
    else:
        print(outputFile, outputFile.exists())
        assert not outputFile.exists(), \
            f"Output file {outputFile} already exists !!! " \
            f"If you want to continue quantizing audio files, please check the --resume option."
    assert len(seqNames) > 0, \
        "No file to be quantized!"

    # Load clustering args
    pathCheckpoint = Path(pathClusteringCheckpoint)
    assert pathCheckpoint.suffix == ".pt"
    if Path(str(pathCheckpoint.with_suffix('')) + '_args.json').exists():
        pathConfig = Path(str(pathCheckpoint.with_suffix('')) + '_args.json')
    elif (pathCheckpoint.parent / "checkpoint_args.json").exists():
        pathConfig = pathCheckpoint.parent / "checkpoint_args.json"
    else:
        assert False, \
            f"Args file not found in the directory {pathCheckpoint.parent}"
    clustering_args = readArgs(pathConfig)
    print("")
    print(f"Clustering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}")
    print('-' * 50)

    # Load the ClusterModule
    print("")
    print(f"Loading ClusterModule at {pathCheckpoint}")
    clusterModule = loadClusterModule(pathCheckpoint)
    if not cpu:
        clusterModule.cuda()
    print("ClusterModule loaded!")

    # Quantize the activation files
    print("")
    print(f"Quantizing activation files and saving outputs to {outputFile}...")
    f = open(outputFile, "a")
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(pathActivations, file_path)

        # Quantize the file
        quantLine = quantize_file(file_path, clusterModule, cpu=cpu)

        # Save the output
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLine = "\t".join([file_name, quantLine])
        if addEndLine:
            f.write("\n" + outLine)
        else:
            f.write(outLine)
            addEndLine = True
    bar.finish()
    print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
    f.close()
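# ---------------------------------------------------------------------------
# The quantize_file helper used above is defined elsewhere in the repository.
# The sketch below only illustrates the per-file step, under the assumption
# that a kMeanCluster-style module returns a distance to each centroid and
# that activations are stored as (nFrames, dim) '.pt' tensors; the function
# name is hypothetical.
import torch

def quantize_file_sketch(file_path, clusterModule, cpu=False):
    features = torch.load(file_path)          # (nFrames, dim) activations
    if not cpu:
        features = features.cuda()
    with torch.no_grad():
        # Distances to each centroid: (1, nFrames, nClusters)
        dists = clusterModule(features.unsqueeze(0))
        # The quantized unit of a frame is the index of its closest centroid
        units = torch.argmin(dists, dim=-1).squeeze(0)
    return ','.join(str(int(u)) for u in units)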
def main(argv):
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1,
                                get_encoded=False, keep_hidden=True):
            """Load a CPC feature maker from a CPC checkpoint file."""
            # Set the LSTM level
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None
            # Load the CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint,
                                                             updateConfig=updateConfig)
            # Keep hidden units at LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden
            # Build a CPC feature maker from the CPC model
            # featureMaker = fl.FeatureModule(model, get_encoded=get_encoded)
            # return featureMaker
            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(
                args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(
                args.load, loadStateDict=not args.no_pretraining)
        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """Load an lstm_lm model from a checkpoint."""
            # Set up the args Namespace
            model_args = argparse.Namespace(task="language_modeling",
                                            output_dictionary_size=-1,
                                            data=path_data,
                                            path=path_checkpoint)
            # Load the model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            model = models[0]
            return model

        model = loadCheckpoint(args.load[0], args.pathDB)
        dim_features = 768

    dim_inter = args.dim_inter

    # Now the criterion
    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter,
                                                        len(speakers))
        speakers_factorized.load_state_dict(
            torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            # SVD-based nullspace: keep the right-singular vectors whose
            # singular values fall below the tolerance
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcondt = torch.finfo(st.dtype).eps * max(Mt, Nt)
            else:
                rcondt = rcond
            tolt = torch.max(st) * rcondt
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            # nullspace.backward(torch.ones_like(nullspace), retain_graph=True)
            return nullspace

        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)

    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'
        if not args.CTC:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones, args.get_encoded)
        else:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print("Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers),
                            nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile,
                                 map_location=torch.device('cpu'))['state_dict']['Ck']
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key,
        centerpushSettings=centerpushSettings)
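# ---------------------------------------------------------------------------
# A quick, self-contained sanity check for the SVD-based nullspace used above:
# every column of the returned basis should be numerically annihilated by the
# input matrix. This treats my_nullspace as if it were lifted to module scope,
# and the shapes below are illustrative.
import torch

def check_nullspace():
    A = torch.randn(64, 256)        # e.g. a (speakers x features) classifier weight
    N = my_nullspace(A)             # (256, 256 - rank(A)) basis of the nullspace
    residual = (A @ N).abs().max()
    print(f"max |A @ N| = {residual:.2e}")   # expected to be ~1e-5 or smaller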
def main(argv):
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)
    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    dim_features = hidden_encoder if args.get_encoded else hidden_gar

    # Now the criterion
    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if not args.CTC:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones, args.get_encoded)
        else:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones, args.get_encoded)
    else:
        print("Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
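# ---------------------------------------------------------------------------
# Entry point assumed for the probing scripts above; parse_args is defined
# elsewhere in the repository, so this wrapper is illustrative only.
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])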
args.pathDB = os.path.abspath(args.pathDB)

if not args.load:
    assert os.path.exists(args.pathOutput) is False, \
        f"The output file {args.pathOutput} already exists, please check the option --load !"
    assert os.path.exists(os.path.join(os.path.dirname(args.pathOutput),
                                       "checkpoint_last.pt")) is False, \
        "Found checkpoint_last.pt in the output directory, please check the option --load !"

print(args)
seqNames, speakers = findAllSeqs(args.pathDB,
                                 speaker_level=args.recursionLevel,
                                 extension=args.extension,
                                 loadCache=True)

if args.seqList is not None:
    seqNames = filterSeqs(args.seqList, seqNames)
if args.debug:
    nsamples = 1000
    print(f"Debug mode activated, getting only {nsamples} samples!")
    shuffle(seqNames)
    seqNames = seqNames[:nsamples]
if args.getDistanceEstimation:
    shuffle(seqNames)
    seqNames = seqNames[:5000]

print("")
print(f'Loading audio data at {args.pathDB}')
start_time = time.time()
dataset = AudioBatchData(args.pathDB,
                         args.sizeWindow,
                         seqNames,
# Output directory
if not os.path.isdir(args.output):
    os.mkdir(args.output)

name = f"_{args.name}" if args.command == "per" else ""
pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

phoneLabels, nPhones = parseSeqLabels(args.pathPhone)

inSeqs, _ = findAllSeqs(args.pathDB, extension=args.file_extension)

# Datasets
if args.command == 'train' and args.pathTrain is not None:
    seqTrain = filterSeqs(args.pathTrain, inSeqs)
else:
    seqTrain = inSeqs

if args.pathVal is None and args.command == 'train':
    random.shuffle(seqTrain)
    sizeTrain = int(0.9 * len(seqTrain))
    seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
elif args.pathVal is not None:
    seqVal = filterSeqs(args.pathVal, inSeqs)
else:
    raise RuntimeError("No validation dataset found for PER computation")

if args.debug:
    seqVal = seqVal[:100]
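# ---------------------------------------------------------------------------
# The tee trick above duplicates stdout into a log file at the file-descriptor
# level, so output from C extensions and subprocesses is captured as well as
# Python prints. A standalone (Unix-only) sketch, with an illustrative path:
import os
import subprocess
import sys

tee = subprocess.Popen(["tee", "/tmp/demo_logs.txt"], stdin=subprocess.PIPE)
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())  # stdout now flows through tee
print("this line appears on the console AND in /tmp/demo_logs.txt")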
def main(args):
    # import ptvsd
    # ptvsd.enable_attach(('0.0.0.0', 7309))
    # print("Attach debugger now")
    # ptvsd.wait_for_attach()

    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    os.makedirs(args.pathCheckpoint, exist_ok=True)
    if not args.onlyCapture and not args.only_classif_metric:
        json.dump(vars(args),
                  open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'), 'wt'))
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align',
                                     'phone_align', 'pred',
                                     'cpcctc_align', 'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        # assert len(whatToSave) > 0
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq,
                                totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:
        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None

        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # Could use this option for loading the best state when not running actual
        # training, but relying on the internal CPC accuracy isn't very reliable.
        # [!] Caution: because of how checkpoints are captured, "best in this part
        # of training" is saved as "best" (apart from capturing the current state).
        # So if the best state is in epoch 100 and training is paused and resumed
        # from a checkpoint in epoch 150, the checkpoint from epoch 200 has "best
        # from epoch 150" saved as globally best (but this is internal-CPC-score
        # best anyway, which is quite vague).
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load,
                         load_nullspace=args.nullspace,
                         updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True
    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)
        CPChiddenGar, CPChiddenEncoder = (cpcModel.gAR.getDimOutput(),
                                          cpcModel.gEncoder.getDimOutput())

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING \
        if isinstance(cpcModel, model.CPCModelNullspace) \
        else cpcModel.gEncoder.DOWNSAMPLING

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for i in range(len(logs["epoch"])):
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    if args.supervised_classif_metric:
        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        # Speaker separability from CNN encodings is not supported; it could be
        # added but is perhaps not very useful
        dim_ctx_features = CPChiddenGar

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)

            if not args.CTCphones:
                print("Running phone separability with aligned phones")
            else:
                print("Running phone separability with CTC loss")

        def constructPhoneCriterionAndOptimizer():
            if not args.CTCphones:
                # print("Running phone separability with aligned phones")
                phone_criterion = cr.PhoneCriterion(dim_features,
                                                    nPhonesInData,
                                                    args.phone_get_encoded,
                                                    nLayers=args.linsep_net_layers)
            else:
                # print("Running phone separability with CTC loss")
                phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                       nPhonesInData,
                                                       args.phone_get_encoded,
                                                       nLayers=args.linsep_net_layers)
            phone_criterion.cuda()
            phone_criterion = torch.nn.DataParallel(phone_criterion,
                                                    device_ids=range(args.nGPU))

            # Optimizer
            phone_g_params = list(phone_criterion.parameters())
            phone_optimizer = torch.optim.Adam(phone_g_params,
                                               lr=args.linsep_lr,
                                               betas=(args.linsep_beta1, args.linsep_beta2),
                                               eps=args.linsep_epsilon)
            return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print("Running speaker separability")

        def constructSpeakerCriterionAndOptimizer():
            speaker_criterion = cr.SpeakerCriterion(dim_ctx_features,
                                                    len(speakers),
                                                    nLayers=args.linsep_net_layers)
            speaker_criterion.cuda()
            speaker_criterion = torch.nn.DataParallel(speaker_criterion,
                                                      device_ids=range(args.nGPU))
            speaker_g_params = list(speaker_criterion.parameters())
            speaker_optimizer = torch.optim.Adam(speaker_g_params,
                                                 lr=args.linsep_lr,
                                                 betas=(args.linsep_beta1, args.linsep_beta2),
                                                 eps=args.linsep_epsilon)
            return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow,
                                         seqTrain, phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow,
                                       seqVal, phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size,
                                                            "uniform", True,
                                                            numWorkers=0)
        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size,
                                                        'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            if not os.path.exists(log_path_for_epoch):
                os.makedirs(log_path_for_epoch)
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            if not os.path.exists(log_path_phoneme):
                os.makedirs(log_path_phoneme)
            if not os.path.exists(log_path_speaker):
                os.makedirs(log_path_speaker)
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir,
                                                         str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch, "speaker/")
                if not os.path.exists(checkpoint_path_phoneme):
                    os.makedirs(checkpoint_path_phoneme)
                if not os.path.exists(checkpoint_path_speaker):
                    os.makedirs(checkpoint_path_speaker)
            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with the classification model above
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with the classification model above
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)
    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # Caution [!]: if a checkpoint directory is given, this captures for the
        # last checkpoint (last saved state); to use a specific checkpoint,
        # provide the full checkpoint file path. The "last state" is used, not
        # "best in internal CPC accuracy", either way.
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # Caution [!]: same as above, the last saved state is used if a
        # checkpoint directory is given; provide a full checkpoint file path
        # to use a specific checkpoint.
        trainedEpoch = len(logs["epoch"]) - 1
        # runLinsepClassificationTraining is defined above if args.supervised_classif_metric
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
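# ---------------------------------------------------------------------------
# utils.ramp_scheduling_function is not shown in this file. Given how it is
# plugged into LambdaLR above, a plausible definition is a linear warm-up of
# the learning-rate multiplier; this is an assumption, not the repository's
# actual code.
def ramp_scheduling_function_sketch(n_epoch_ramp, epoch):
    # Scale the base learning rate linearly up to 1 over the ramp,
    # then keep it constant.
    if epoch >= n_epoch_ramp:
        return 1.0
    return (epoch + 1) / n_epoch_ramp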
def main(pathCheckpoint, pathDB, pathOutputDir, batch_size=8, debug=False,
         file_extension='.wav', layer='all', max_size_seq=64000,
         output_file_extension='.txt', recursionLevel=2, seqList=None,
         audio_features_fn='mfcc_features.pt',
         image_features_fn='resnet_features.pt',
         cpc_model_path=None, cpc_gru_level=-1, zr_format=False):
    args = argparse.Namespace(**locals())

    print("=============================================================")
    print(f"Extract activations from VG model for {pathDB}")
    print("=============================================================")

    # Initialize the feature extraction config
    # /!\ Code duplication with preprocessing.py
    # The feature config should probably be stored on disk.
    _audio_feat_config = dict(type='mfcc', delta=True, alpha=0.97, n_filters=40,
                              window_size=0.025, frame_shift=0.010,
                              audio_features_fn=audio_features_fn)
    _images_feat_config = dict(model='resnet',
                               image_features_fn=image_features_fn)

    if cpc_model_path is not None:
        if audio_features_fn == 'mfcc_features.pt':
            audio_features_fn = 'cpc_features.pt'
        _audio_feat_config = dict(type='cpc', model_path=cpc_model_path,
                                  audio_features_fn=audio_features_fn,
                                  strict=False, seq_norm=False,
                                  max_size_seq=10240,
                                  gru_level=cpc_gru_level, on_gpu=True)

    # Find all sequences
    print("")
    print(f"Looking for all {file_extension} files in {pathDB}")
    seqNames, _ = findAllSeqs(pathDB,
                              speaker_level=recursionLevel,
                              extension=file_extension,
                              loadCache=True)
    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][-1])[1].endswith(file_extension):
        print("Seems like the _seq_cache.txt does not contain the correct extension; reloading the file list")
        seqNames, _ = findAllSeqs(pathDB,
                                  speaker_level=recursionLevel,
                                  extension=file_extension,
                                  loadCache=False)
    print(f"Done! Found {len(seqNames)} files!")

    # Filter specific sequences
    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
        print(f"Done! {len(seqNames)} files remaining after filtering!")
    assert len(seqNames) > 0, \
        "No file to be processed!"

    pathOutputDir = Path(pathOutputDir)
    print("")
    print(f"Creating the output directory at {pathOutputDir}")
    pathOutputDir.mkdir(parents=True, exist_ok=True)
    writeArgs(pathOutputDir / "_info_args.json", args)

    # Debug mode
    if debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only loading {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Load audio features
    print("")
    print(f"Loading audio features for {pathDB}")
    pathDB = Path(pathDB)
    if seqList is None:
        cache_fpath = pathDB / args.audio_features_fn
        if cache_fpath.exists():
            print(f"Found cached features ({cache_fpath}). Loading them.")
            features = torch.load(cache_fpath)
        else:
            print('No cached features. Computing them from scratch.')
            audio_fpaths = [pathDB / s[1] for s in seqNames]
            features = compute_audio_features(audio_fpaths, max_size_seq,
                                              _audio_feat_config)
            print(f'Caching features ({cache_fpath}).')
            torch.save(features, cache_fpath)
    else:
        print('Computing features.')
        audio_fpaths = [pathDB / s[1] for s in seqNames]
        features = compute_audio_features(audio_fpaths, max_size_seq,
                                          _audio_feat_config)

    # Load the VG model
    print("")
    print(f"Loading VG model from {pathCheckpoint}")
    vg_model = torch.load(pathCheckpoint)
    print("VG model loaded!")

    # Extract activations
    print("")
    print(f"Extracting activations and saving outputs to {pathOutputDir}...")
    data = torch.utils.data.DataLoader(dataset=features,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=lambda x: dataset.batch_audio(x, max_frames=None))
    i_next = 0
    zr_keywords = ['phonetic', 'lexical', 'syntactic', 'semantic']
    if zr_format:
        splitted_path = str(pathDB).split('/')
        for keyword in zr_keywords:
            if keyword in splitted_path:
                keyword_idx = splitted_path.index(keyword)
                break
        suffix = '/'.join(splitted_path[keyword_idx:])
    else:
        suffix = ""
    for au, l in tqdm(data):
        activations = vg_model.SpeechEncoder.introspect(au.cuda(), l.cuda())
        fnames = [s[1] for s in seqNames[i_next: i_next + batch_size]]
        if layer == 'all':
            for k in activations:
                save_activations(activations[k], pathOutputDir / k / suffix,
                                 fnames, output_file_extension)
        elif layer in activations:
            save_activations(activations[layer], pathOutputDir / layer / suffix,
                             fnames, output_file_extension)
        i_next += batch_size
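# ---------------------------------------------------------------------------
# save_activations is imported from elsewhere in the repository. A minimal
# sketch consistent with its call sites above (one activation tensor per
# utterance, output named after the sequence file, '.txt' or '.pt' output);
# the function name is hypothetical and details may differ.
import os
import torch

def save_activations_sketch(batch_activations, output_dir, fnames, ext):
    os.makedirs(str(output_dir), exist_ok=True)
    for act, fname in zip(batch_activations, fnames):
        stem = os.path.splitext(os.path.basename(fname))[0]
        out_path = os.path.join(str(output_dir), stem + ext)
        if ext == '.txt':
            # One frame per line, space-separated values
            with open(out_path, 'w') as f:
                for frame in act:
                    f.write(' '.join(f'{v:.6f}' for v in frame.tolist()) + '\n')
        else:  # assume a torch-loadable output, e.g. '.pt'
            torch.save(act.cpu(), out_path)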
def main(args):
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')

    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)
    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        for i in range(len(logs["epoch"])):
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset,
        valDataset,
        batchSize,
        args.samplingType,
        cpcModel,
        cpcCriterion,
        args.nEpoch,
        args.pathCheckpoint,
        optimizer,
        scheduler,
        logs)
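# ---------------------------------------------------------------------------
# From the loading code above, a training checkpoint is a plain dict with at
# least the "cpcCriterion" and "optimizer" entries; the model-weights key
# ("gEncoder" below) is an assumption, since that part is handled inside
# fl.loadModel. A hedged sketch of the matching save side:
import torch

def save_checkpoint_sketch(cpcModel, cpcCriterion, optimizer, path):
    state_dict = {
        "gEncoder": cpcModel.module.state_dict(),        # key name assumed
        "cpcCriterion": cpcCriterion.module.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state_dict, path)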
def main(pathActivations, pathOutput, nGroups=1, nClusters=50, MAX_ITER=100,
         batchSizeGPU=50, debug=False, extension='.pt',
         getDistanceEstimation=False, load=False, perIterSize=-1,
         recursionLevel=2, save=False, save_last=5, seqList=None):
    # Test that the extension is valid
    if extension not in ['.txt', '.npy', '.pt']:
        raise ValueError(f'Activation file extension invalid ({extension})')

    torch.cuda.empty_cache()

    args = argparse.Namespace(**locals())
    # Export absolute paths for later use
    pathActivations = os.path.abspath(pathActivations)
    pathOutput = os.path.abspath(pathOutput)

    if not load:
        assert os.path.exists(pathOutput) is False, \
            f"The output file {pathOutput} already exists, please check the option --load !"
        assert os.path.exists(os.path.join(os.path.dirname(pathOutput),
                                           "checkpoint_last.pt")) is False, \
            "Found checkpoint_last.pt in the output directory, please check the option --load !"

    print(args)
    seqNames, speakers = findAllSeqs(pathActivations,
                                     speaker_level=recursionLevel,
                                     extension=extension,
                                     loadCache=True)

    if seqList is not None:
        seqNames = filterSeqs(seqList, seqNames)
    if debug:
        nsamples = 1000
        print(f"Debug mode activated, getting only {nsamples} samples!")
        shuffle(seqNames)
        seqNames = seqNames[:nsamples]
    if getDistanceEstimation:
        shuffle(seqNames)
        seqNames = seqNames[:5000]

    print("")
    print(f'Loading activations at {pathActivations}')
    start_time = time.time()
    dataset = SequentialData(pathActivations, seqNames, None)
    print(f"Dataset loaded in {time.time()-start_time} seconds!")
    print("")

    nGPUs = torch.cuda.device_count()
    if nGPUs == 0:
        raise RuntimeError('No GPU found')
    batchSize = batchSizeGPU * nGPUs
    dataloader = dataset.getDataLoader(batchSize, numWorkers=0)
    print(f"Length of dataLoader: {len(dataloader)}")
    print("")

    # Create the output directory if needed
    if os.path.dirname(pathOutput) and not os.path.exists(os.path.dirname(pathOutput)):
        Path(os.path.dirname(pathOutput)).mkdir(parents=True, exist_ok=True)

    pathConfig = f"{os.path.splitext(pathOutput)[0]}_args.json"
    with open(pathConfig, 'w') as file:
        json.dump(vars(args), file, indent=2)

    out_state_dict = {}
    print("Starting the clustering...")
    start_time = time.time()
    # Pass an identity function to skip feature extraction, since we start
    # directly from the activations
    clusters = kMeanGPU(dataloader, lambda x: x, nClusters, nGroups,
                        perIterSize=perIterSize,
                        MAX_ITER=MAX_ITER,
                        save=save, load=load,
                        save_dir=os.path.dirname(pathOutput),
                        save_last=save_last).cpu()

    print(f'Ran clustering in {time.time() - start_time:.2f} seconds')

    clusterModule = kMeanCluster(clusters)
    out_state_dict["state_dict"] = clusterModule.state_dict()
    out_state_dict["n_clusters"] = nClusters
    out_state_dict['dim'] = clusters.size(2)
    torch.save(out_state_dict, pathOutput)
    with open(pathConfig, 'w') as file:
        json.dump(vars(args), file, indent=2)
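# ---------------------------------------------------------------------------
# A short usage sketch for reading the checkpoint written above and rebuilding
# the cluster module. The key names ("state_dict", "n_clusters", "dim") come
# from this script; the kMeanCluster constructor taking a (1, k, dim) tensor
# of centroids is an assumption, and the path is illustrative.
import torch

ckpt = torch.load("checkpoint.pt", map_location="cpu")
clusterModule = kMeanCluster(torch.zeros(1, ckpt["n_clusters"], ckpt["dim"]))
clusterModule.load_state_dict(ckpt["state_dict"])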