def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False,
                        keep_hidden=True, load_nullspace=False):
    """Load a CPC feature maker from a CPC checkpoint file."""
    # Set the LSTM level
    if gru_level is not None and gru_level > 0:
        updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
    else:
        updateConfig = None
    # Load the CPC model
    model, nHiddenGar, nHiddenEncoder = loadModel([pathCheckpoint],
                                                  updateConfig=updateConfig,
                                                  load_nullspace=load_nullspace)
    # Keep hidden units at the LSTM layers on sequential batches
    if load_nullspace:
        model.cpc.gAR.keepHidden = keep_hidden
    else:
        model.gAR.keepHidden = keep_hidden
    # Build the CPC feature maker from the CPC model
    featureMaker = FeatureModule(model, get_encoded=get_encoded)
    return featureMaker
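# Usage sketch (not from the source): loadCPCFeatureMaker pairs naturally with
# buildFeature, the helper used by the ABX script below. The checkpoint and
# audio paths here are placeholders.
feature_maker = loadCPCFeatureMaker("checkpoints/cpc.pt",  # placeholder path
                                    gru_level=2, get_encoded=False,
                                    keep_hidden=True)
feature_maker = feature_maker.cuda().eval()
features = buildFeature(feature_maker, "audio/sample.wav",  # placeholder file
                        seqNorm=False, strict=True, maxSizeSeq=64000)
# features: tensor of shape (1, nFrames, feature_dim)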
def main(argv):
    args = parse_args(argv)

    if args.load == 'from_checkpoint':
        # Checkpoint
        model = loadModel([args.path_checkpoint])[0]
        model.gAR.keepHidden = True
        # Feature maker
        feature_maker = FeatureModule(model, args.get_encoded).cuda().eval()

        def feature_function(x):
            return buildFeature(feature_maker, x,
                                seqNorm=args.seq_norm,
                                strict=args.strict,
                                maxSizeSeq=args.max_size_seq)
    elif args.load == 'from_pre_computed':
        def feature_function(x):
            return torch.load(x, 'cpu')

    # Modes
    if args.mode == 'all':
        modes = ["within", "across"]
    else:
        modes = [args.mode]

    distance_mode = 'cosine'
    step_feature = 1 / args.feature_size

    # Get the list of sequences
    seq_list, _ = findAllSeqs(args.path_dataset,
                              extension=args.file_extension)
    seq_list = [(str(Path(x).stem), str(Path(args.path_dataset) / x))
                for (_, x) in seq_list]

    if args.debug:
        seq_list = seq_list[:1000]

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 seq_norm=args.seq_norm,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    out_dir = Path(args.path_checkpoint).parent if args.out is None \
        else Path(args.out)
    out_dir.mkdir(exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)
                                     randomOffset=False, numWorkers=0)
db_test = LibriSelectionDataset(sizeWindow=20480,
                                db_wav_root=DB_WAV_ROOT,
                                fps_list=TS_LIST_PATH,
                                label_path=LABEL_PATH,
                                n_process_loader=8,
                                MAX_SIZE_LOADED=400000000)
test_loader = db_test.getDataLoader(batchSize=128,
                                    type='sequential',
                                    randomOffset=False,
                                    numWorkers=0)

# Load the model: c, z, label = feat_gen(raw_audio, label)
feat_gen, d_c, d_z = fl.loadModel([CPC_CHECKPOINT_PATH], loadStateDict=True)
if nGPU > 0:
    feat_gen = feat_gen.cuda()
    feat_gen = torch.nn.DataParallel(feat_gen, device_ids=range(nGPU))
feat_gen.optimize = False
feat_gen.eval()
for g in feat_gen.parameters():
    g.requires_grad = False

# Create the classifier: clf()
d_feat = int('c' in SEL_FEAT) * d_c + int('z' in SEL_FEAT) * d_z
n_classes = len(db_train.speakers)
# clf = MLP(d_feat, 2048, n_classes)
clf = SpeakerClf(d_feat, n_classes)
if nGPU > 0:
    clf = clf.cuda()
def main(argv):
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1,
                                get_encoded=False, keep_hidden=True):
            """Load a CPC feature maker from a CPC checkpoint file."""
            # Set the LSTM level
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None
            # Load the CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(
                pathCheckpoint, updateConfig=updateConfig)
            # Keep hidden units at the LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden
            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(
                args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(
                args.load, loadStateDict=not args.no_pretraining)
        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """Load an lstm_lm model from a checkpoint."""
            # Set up the args Namespace
            model_args = argparse.Namespace(task="language_modeling",
                                            output_dictionary_size=-1,
                                            data=path_data,
                                            path=path_checkpoint)
            # Load the model
            models, _model_args = checkpoint_utils.load_model_ensemble(
                [model_args.path])
            return models[0]

        model = loadCheckpoint(args.load[0], args.pathDB)
        dim_features = 768

    dim_inter = args.dim_inter

    # Now the criterion
    if args.mode in ("phonemes_nullspace", "speakers_nullspace"):
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features,
                                                        dim_inter,
                                                        len(speakers))
        speakers_factorized.load_state_dict(
            torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcond = torch.finfo(st.dtype).eps * max(Mt, Nt)
            tolt = torch.max(st) * rcond
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            return nullspace

        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(
            speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)

    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'
        if not args.CTC:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print("Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features,
                                                  dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))

    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers),
                            nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile,
                                 map_location=torch.device('cpu'))['state_dict']['Ck']
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key,
        centerpushSettings=centerpushSettings)
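# Sanity check (not from the source) for the my_nullspace construction above:
# the returned columns should span the right nullspace of the weight matrix,
# i.e. W @ N ~ 0. Assumes my_nullspace is reachable at module level; the
# shapes are illustrative.
import torch

W = torch.randn(64, 256)   # e.g. a (speakers x features) classifier weight
N = my_nullspace(W)        # nullspace basis, here of shape (256, 192)
print(N.shape)
print(torch.allclose(W @ N, torch.zeros(W.shape[0], N.shape[1]), atol=1e-4))
# -> True: projecting features onto N removes everything W can linearly predict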
def main(argv):
    # Args parser
    args = parseArgs(argv)

    print("=============================================================")
    print(f"Quantizing data from {args.pathDB}")
    print("=============================================================")

    # Check if the output directory exists
    if not os.path.exists(args.pathOutput):
        print("")
        print(f"Creating the output directory at {args.pathOutput}")
        Path(args.pathOutput).mkdir(parents=True, exist_ok=True)

    # Get the splits
    if args.split:
        assert len(args.split.split("-")) == 2 \
            and int(args.split.split("-")[1]) >= int(args.split.split("-")[0]) >= 1, \
            "SPLIT must be of the form idxSplit-numSplits " \
            "(numSplits >= idxSplit >= 1), e.g. --split 1-20"
        idx_split, num_splits = args.split.split("-")
        idx_split = int(idx_split)
        num_splits = int(num_splits)

    # Find all sequences
    print("")
    print(f"Looking for all {args.file_extension} files in {args.pathDB} "
          f"with speakerLevel {args.recursionLevel}")
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     speaker_level=args.recursionLevel,
                                     extension=args.file_extension,
                                     loadCache=True)
    if args.pathSeq:
        with open(args.pathSeq, 'r') as f:
            seqs = set(x.strip() for x in f)
        filtered = []
        for s in seqNames:
            if s[1].split('/')[-1].split('.')[0] in seqs:
                filtered.append(s)
        seqNames = filtered

    print(f"Done! Found {len(seqNames)} files and {len(speakers)} speakers!")

    if args.separate_speaker:
        seqNames_by_speaker = {}
        for seq in seqNames:
            speaker = seq[1].split("/")[args.recursionLevel - 1]
            if speaker not in seqNames_by_speaker:
                seqNames_by_speaker[speaker] = []
            seqNames_by_speaker[speaker].append(seq)

    # Check if the output file exists
    if not args.split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    if args.separate_speaker is False:
        outputFile = os.path.join(args.pathOutput, nameOutput)
        assert not os.path.exists(outputFile), \
            f"Output file {outputFile} already exists !!!"
    # Get the splits
    if args.split:
        startIdx = len(seqNames) // num_splits * (idx_split - 1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(len(seqNames) // num_splits * idx_split,
                         len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(f"Quantizing split {idx_split} out of {num_splits} splits, "
              f"with {len(seqNames)} files (idx in range({startIdx}, {endIdx})).")

    # Debug mode
    if args.debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only loading {nsamples} samples!")
        # shuffle(seqNames)
        seqNames = seqNames[:nsamples]

    # Load the clustering args
    assert args.pathCheckpoint[-3:] == ".pt"
    if os.path.exists(args.pathCheckpoint[:-3] + "_args.json"):
        pathConfig = args.pathCheckpoint[:-3] + "_args.json"
    elif os.path.exists(os.path.join(os.path.dirname(args.pathCheckpoint),
                                     "checkpoint_args.json")):
        pathConfig = os.path.join(os.path.dirname(args.pathCheckpoint),
                                  "checkpoint_args.json")
    else:
        assert False, \
            f"Args file not found in the directory {os.path.dirname(args.pathCheckpoint)}"
    clustering_args = readArgs(pathConfig)
    print("")
    print(f"Clustering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}")
    print('-' * 50)

    # Load the ClusterModule
    clusterModule = loadClusterModule(args.pathCheckpoint,
                                      norm_vec_len=args.norm_vec_len)
    clusterModule.cuda()

    # Load the FeatureMaker
    print("")
    print("Loading CPC FeatureMaker")
    if 'level_gru' in vars(clustering_args) \
            and clustering_args.level_gru is not None:
        updateConfig = argparse.Namespace(nLevelsGRU=clustering_args.level_gru)
    else:
        updateConfig = None
    model = loadModel([clustering_args.pathCheckpoint],
                      updateConfig=updateConfig)[0]
    # If we don't use the batch implementation, we can have the LSTM keep its
    # hidden units, which improves the quality of the quantized units
    if args.nobatch:
        model.gAR.keepHidden = True
    featureMaker = FeatureModule(model, clustering_args.encoder_layer)
    if clustering_args.dimReduction is not None:
        dimRed = loadDimReduction(clustering_args.dimReduction,
                                  clustering_args.centroidLimits)
        featureMaker = torch.nn.Sequential(featureMaker, dimRed)
    if not clustering_args.train_mode:
        featureMaker.eval()
    featureMaker.cuda()

    def feature_function(x):
        if args.nobatch is False:
            res0 = buildFeature_batch(featureMaker, x,
                                      seqNorm=False,
                                      strict=args.strict,
                                      maxSizeSeq=args.max_size_seq,
                                      batch_size=args.batch_size)
        else:
            res0 = buildFeature(featureMaker, x,
                                seqNorm=False,
                                strict=args.strict)
        if args.norm_vec_len:
            # [!] we actually used CPC_audio/scripts/quantize_audio.py for that in the end
            res0Lengths = torch.sqrt((res0 * res0).sum(2))
            res0 = res0 / res0Lengths.view(*(res0Lengths.shape), 1)
        return res0

    print("CPC FeatureMaker loaded!")

    # Quantization of the files
    print("")
    print("Quantizing audio files...")
    seqQuantLines = []
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(args.pathDB, file_path)

        # Get the features and quantize them
        cFeatures = feature_function(file_path).cuda()
        nGroups = cFeatures.size(-1) // clusterModule.Ck.size(-1)
        cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))
        if len(vals) > 2 and int(vals[-1]) > 9400000:  # LibriLight, to avoid OOM
            clusterModule = clusterModule.cpu()
            cFeatures = cFeatures.cpu()
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
            clusterModule = clusterModule.cuda()
        else:
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
        qFeatures = qFeatures[0].detach().cpu().numpy()

        # Transform to a quantized line
        quantLine = ",".join("-".join(str(i) for i in item)
                             for item in qFeatures.reshape(-1, nGroups))
        seqQuantLines.append(quantLine)

    bar.finish()
    print(f"...done {len(seqQuantLines)} files in {time() - start_time} seconds.")

    # Save the outputs
    print("")
    print(f"Saving outputs to {outputFile}")
    outLines = []
    for vals, quantln in zip(seqNames, seqQuantLines):
        file_path = vals[1]
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLines.append("\t".join([file_name, quantln]))
    with open(outputFile, "w") as f:
        f.write("\n".join(outLines))
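# Reader sketch (not from the source) for the quantized_outputs.txt format
# written above: one "<file_name>\t<frame>,<frame>,..." line per file, each
# frame being nGroups hyphen-joined cluster indices.
import numpy as np

def read_quantized_outputs(path):
    units = {}
    with open(path) as f:
        for line in f:
            file_name, quantln = line.rstrip("\n").split("\t")
            units[file_name] = np.array([[int(i) for i in frame.split("-")]
                                         for frame in quantln.split(",")])
    return units  # file_name -> (nFrames, nGroups) int array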
nGPUs = torch.cuda.device_count()
batchSize = args.batchSizeGPU * nGPUs
trainLoader = dataset.getDataLoader(batchSize, "uniform", False,
                                    numWorkers=0)
print(f"Length of dataLoader: {len(trainLoader)}")
print("")

if args.level_gru is None:
    updateConfig = None
else:
    updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)
model = loadModel([args.pathCheckpoint],
                  updateConfig=updateConfig,
                  load_nullspace=args.nullspace)[0]
featureMaker = FeatureModule(model, args.encoder_layer)
print("Checkpoint loaded!")
print("")

if not args.train_mode:
    featureMaker.eval()
featureMaker.cuda()

# Check if the output directory exists
if not os.path.exists(os.path.dirname(args.pathOutput)) \
        and os.path.dirname(args.pathOutput):
    Path(os.path.dirname(args.pathOutput)).mkdir(parents=True,
                                                 exist_ok=True)
def main(argv):
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    dim_features = hidden_encoder if args.get_encoded else hidden_gar

    # Now the criterion
    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if not args.CTC:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        print("Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
    checkpoint_url = 'https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt'
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                    progress=False,
                                                    map_location="cuda:0")
    loadArgs(locArgs, argparse.Namespace(**checkpoint["config"]))
    encoderNet = getEncoder(locArgs)
    arNet = getAR(locArgs)
    model = CPCModel(encoderNet, arNet)
    if not args.no_pretraining:
        model.load_state_dict(checkpoint["weights"], strict=False)
    feature_maker = model
    hiddenGar = locArgs.hiddenGar
    print(feature_maker, hiddenGar)
    print()
else:
    feature_maker, hiddenGar, _ = loadModel(
        [args.pathCheckpoint], loadStateDict=not args.no_pretraining)
feature_maker.cuda()
feature_maker = torch.nn.DataParallel(feature_maker)

phone_criterion = CTCphone_criterion(hiddenGar, nPhones, args.LSTM,
                                     seqNorm=args.seqNorm,
                                     dropout=args.dropout,
                                     reduction=args.loss_reduction)
phone_criterion.cuda()
phone_criterion = torch.nn.DataParallel(phone_criterion)

print(f"Loading the validation dataset at {args.pathDB}")
datasetVal = SingleSequenceDataset(args.pathDB, seqVal,
args = parser.parse_args()

if not os.path.isdir(args.pathOut):
    os.mkdir(args.pathOut)

with open(os.path.join(os.path.dirname(args.pathOut),
                       f"{os.path.basename(args.pathOut)}.json"), 'w') as file:
    json.dump(vars(args), file, indent=2)

outData = [x[1] for x in
           findAllSeqs(args.pathDB,
                       extension=args.extension,
                       loadCache=False)[0]]

featureMaker = loadModel([args.pathCheckpoint])[0]
stepSize = featureMaker.gEncoder.DOWNSAMPLING / 16000
print(f"stepSize : {stepSize}")
featureMaker = FeatureModule(featureMaker, args.getEncoded)
featureMaker.collapse = False

if args.addCriterion:
    criterion, nPhones = loadSupervisedCriterion(args.pathCheckpoint)
    featureMaker = ModelPhoneCombined(featureMaker, criterion,
                                      nPhones, args.oneHot)
if device == "cuda":
    featureMaker = featureMaker.cuda(device=0)
if not args.train_mode:
    featureMaker.eval()
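# Note (assumption, not from the source): with the standard CPC encoder,
# gEncoder.DOWNSAMPLING is 160, so at 16 kHz the stepSize above works out to
# 10 ms per feature frame. Quick frame-to-time sketch under that assumption:
DOWNSAMPLING = 160                    # read from featureMaker.gEncoder in practice
stepSize = DOWNSAMPLING / 16000       # 0.01 s per frame
print([round(i * stepSize, 3) for i in range(5)])  # [0.0, 0.01, 0.02, 0.03, 0.04]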
    def __init__(self, model_path, intermediate_idx=0):
        super().__init__()
        self.model = loadModel([model_path],
                               intermediate_idx=intermediate_idx)[0]
        self.model.gAR.keepHidden = True
def main(args):
    args = parseArgs(args)
    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    os.makedirs(args.pathCheckpoint, exist_ok=True)
    if not args.onlyCapture and not args.only_classif_metric:
        json.dump(vars(args), open(os.path.join(args.pathCheckpoint,
                                                'checkpoint_args.json'), 'wt'))
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True
            logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align',
                                     'phone_align', 'pred', 'cpcctc_align',
                                     'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq,
                                totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:
        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None

        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(
                args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # This option could be used to load the best state when not running actual
        # training, but relying on the internal CPC accuracy isn't very reliable.
        # [!] Caution - because of how checkpoints are captured, "best" means
        # "best in this part of training" (apart from the current state). So if the
        # best state is in epoch 100 and training is paused and resumed from a
        # checkpoint in epoch 150, the checkpoint from epoch 200 has "best from
        # epoch 150" saved as the global best (and this is internal-CPC-score
        # best anyway, which is quite vague).
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load,
                         load_nullspace=args.nullspace,
                         updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at the LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True
    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR network
        arNet = fl.getAR(args)
        cpcModel = model.CPCModel(encoderNet, arNet)
        CPChiddenGar, CPChiddenEncoder = (cpcModel.gAR.getDimOutput(),
                                          cpcModel.gEncoder.getDimOutput())

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING \
        if isinstance(cpcModel, model.CPCModelNullspace) \
        else cpcModel.gEncoder.DOWNSAMPLING

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())
    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture \
            and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for i in range(len(logs["epoch"])):
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    if args.supervised_classif_metric:
        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU
        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        # For speakers, using CNN encodings is not supported; it could be added
        # but is perhaps not very useful
        dim_ctx_features = CPChiddenGar

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)
            if not args.CTCphones:
                print("Running phone separability with aligned phones")
            else:
                print("Running phone separability with CTC loss")

        def constructPhoneCriterionAndOptimizer():
            if not args.CTCphones:
                phone_criterion = cr.PhoneCriterion(dim_features,
                                                    nPhonesInData,
                                                    args.phone_get_encoded,
                                                    nLayers=args.linsep_net_layers)
            else:
                phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                       nPhonesInData,
                                                       args.phone_get_encoded,
                                                       nLayers=args.linsep_net_layers)
            phone_criterion.cuda()
            phone_criterion = torch.nn.DataParallel(phone_criterion,
                                                    device_ids=range(args.nGPU))
            # Optimizer
            phone_g_params = list(phone_criterion.parameters())
            phone_optimizer = torch.optim.Adam(phone_g_params,
                                               lr=args.linsep_lr,
                                               betas=(args.linsep_beta1,
                                                      args.linsep_beta2),
                                               eps=args.linsep_epsilon)
            return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print("Running speaker separability")

        def constructSpeakerCriterionAndOptimizer():
            speaker_criterion = cr.SpeakerCriterion(dim_ctx_features,
                                                    len(speakers),
                                                    nLayers=args.linsep_net_layers)
            speaker_criterion.cuda()
            speaker_criterion = torch.nn.DataParallel(speaker_criterion,
                                                      device_ids=range(args.nGPU))
            speaker_g_params = list(speaker_criterion.parameters())
            speaker_optimizer = torch.optim.Adam(speaker_g_params,
                                                 lr=args.linsep_lr,
                                                 betas=(args.linsep_beta1,
                                                        args.linsep_beta2),
                                                 eps=args.linsep_epsilon)
            return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow,
                                         seqTrain, phoneLabelsData,
                                         len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow,
                                       seqVal, phoneLabelsData,
                                       len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size,
                                                            "uniform", True,
                                                            numWorkers=0)
        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size,
                                                        'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            log_path_for_epoch = os.path.join(args.linsep_logs_dir,
                                              str(numOfEpoch))
            if not os.path.exists(log_path_for_epoch):
                os.makedirs(log_path_for_epoch)
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            if not os.path.exists(log_path_phoneme):
                os.makedirs(log_path_phoneme)
            if not os.path.exists(log_path_speaker):
                os.makedirs(log_path_speaker)
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(
                    args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch,
                                                       "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch,
                                                       "speaker/")
                if not os.path.exists(checkpoint_path_phoneme):
                    os.makedirs(checkpoint_path_phoneme)
                if not os.path.exists(checkpoint_path_speaker):
                    os.makedirs(checkpoint_path_speaker)
            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = \
                    constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with the classification model above
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = \
                    constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with the classification model above
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v
                              for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)
    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # Caution [!] - if a checkpoint directory is given, this captures for the
        # last checkpoint (last saved state); to use a specific checkpoint,
        # provide the full checkpoint file path. Either way it uses the "last
        # state", not the "best in internal CPC accuracy".
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # Caution [!] - same as above: uses the last saved state if a checkpoint
        # directory is given; provide a full checkpoint file path for a
        # specific one.
        trainedEpoch = len(logs["epoch"]) - 1
        # runLinsepClassificationTraining is created above
        # when args.supervised_classif_metric is set
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
def main():
    # Parse and print the args
    args = parse_args()
    logger.info(args)

    # Load the model
    print("")
    print(f"Loading model from {args.path_checkpoint}")
    if args.model == "cpc":
        sys.path.append(os.path.abspath(args.path_cpc))
        from cpc.feature_loader import loadModel, FeatureModule
        model = loadModel([args.path_checkpoint])[0]
    else:
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """Load an lstm_lm model from a checkpoint."""
            # Set up the args Namespace
            model_args = argparse.Namespace(task="language_modeling",
                                            output_dictionary_size=-1,
                                            data=path_data,
                                            path=path_checkpoint)
            # Load the model
            models, _model_args = checkpoint_utils.load_model_ensemble(
                [model_args.path])
            return models[0]

        model = loadCheckpoint(args.path_checkpoint, args.path_data)
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"

    # Register the hooks
    layer_outputs = {}

    def get_layer_output(name):
        def hook(model, input, output):
            if type(output) is tuple:
                layer_outputs[name] = output[0].detach().squeeze(1).cpu().numpy()
            elif type(output) is dict:
                layer_outputs[name] = output["x"].detach().squeeze(0).cpu().numpy()
            else:
                layer_outputs[name] = output.detach().squeeze(0).cpu().numpy()
        return hook

    layer_names = []
    if args.model == "cpc":
        layer_name = os.path.basename(os.path.dirname(args.path_checkpoint))
        layer_names.append(layer_name)
        model.gAR.register_forward_hook(get_layer_output(layer_name))
    else:
        for i in range(len(model.encoder.layers)):
            layer_name = "layer_{}".format(i)
            layer_names.append(layer_name)
            model.encoder.layers[i].register_forward_hook(
                get_layer_output(layer_name))
        layer_name = "last"
        layer_names.append(layer_name)
        model.register_forward_hook(get_layer_output(layer_name))

    model = model.eval().to(device)
    print("Model loaded!")
    print(model)

    # Extract values from the chosen layers and save them to files
    phonetic = "phonetic"
    datasets_path = os.path.join(args.path_data, phonetic)
    datasets = os.listdir(datasets_path)
    print(datasets)
    with torch.no_grad():
        for dataset in datasets:
            print("> {}".format(dataset))
            dataset_path = os.path.join(datasets_path, dataset)
            files = [f for f in os.listdir(dataset_path)
                     if f.endswith(args.file_extension)]
            for i, f in enumerate(files):
                print("Progress {:2.1%}".format(i / len(files)), end="\r")
                input_f = os.path.join(dataset_path, f)
                x, sample_rate = sf.read(input_f)
                x = torch.tensor(x).float().reshape(1, -1).to(device)
                if args.model == "cpc":
                    encodedData = model.gEncoder(x.unsqueeze(1)).permute(0, 2, 1)
                    output = model.gAR(encodedData)
                else:
                    output = model(x, features_only=True)["x"]
                for layer_name, value in layer_outputs.items():
                    output_dir = os.path.join(args.path_output_dir, layer_name,
                                              phonetic, dataset)
                    Path(output_dir).mkdir(parents=True, exist_ok=True)
                    out_f = os.path.join(output_dir,
                                         os.path.splitext(f)[0] + ".txt")
                    np.savetxt(out_f, value)
def main(args):
    args = parseArgs(args)
    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True
            logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)
    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR network
        arNet = fl.getAR(args)
        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],
                                     cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args,
                                    cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())
    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        for i in range(len(logs["epoch"])):
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset,
        valDataset,
        batchSize,
        args.samplingType,
        cpcModel,
        cpcCriterion,
        args.nEpoch,
        args.pathCheckpoint,
        optimizer,
        scheduler,
        logs)
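# Standalone sketch (not from the source) of the warm-up pattern used above:
# LambdaLR rescales the base LR each epoch. The exact shape of
# utils.ramp_scheduling_function is an assumption; a linear ramp
# min(1, (epoch + 1) / n_epoch) is shown here.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2e-4)
n_epoch = 10
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / n_epoch))
for epoch in range(3):
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])  # ~4e-05, ~6e-05, ~8e-05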