def main(argv):
    """Train a linear separability probe (phones or speakers) on top of a
    pretrained feature extractor, optionally projecting the features onto the
    nullspace of a frozen speaker classifier first.

    argv: raw command-line arguments, parsed by parse_args.
    """
    args = parse_args(argv)
    # Training-log skeleton; "saveStep" controls checkpointing frequency.
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False  # NOTE(review): never read again in this function
    # Enumerate every audio sequence in the database with its speaker id.
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False, keep_hidden=True):
            """
            Load CPC Feature Maker from CPC checkpoint file.
            """
            # Set LSTM level
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None
            # Load CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint, updateConfig=updateConfig)
            # Keep hidden units at LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden
            # Build CPC Feature Maker from CPC model
            #featureMaker = fl.FeatureModule(model, get_encoded=get_encoded)
            #return featureMaker
            return model, nHiddenGar, nHiddenEncoder

        # Truncating the GRU stack requires the dedicated loader above.
        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(args.load, loadStateDict=not args.no_pretraining)
        # Probe either the raw encoder output or the autoregressive context.
        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        # Non-CPC path: load a fairseq language-model checkpoint instead.
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load lstm_lm model from checkpoint.
            """
            # Set up the args Namespace
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint)
            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            model = models[0]
            return model

        model = loadCheckpoint(args.load[0], args.pathDB)
        # NOTE(review): hidden size hard-coded for this checkpoint type — confirm
        dim_features = 768
    dim_inter = args.dim_inter
    # Now the criterion
    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        # Load a frozen factorized speaker classifier; its first linear map
        # defines the speaker subspace that is projected out of the features.
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        speakers_factorized.load_state_dict(torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            # Nullspace of At via full SVD: keep the right-singular vectors whose
            # singular values fall at/below the tolerance.
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcondt = torch.finfo(st.dtype).eps * max(Mt, Nt)
            # NOTE(review): if a non-None rcond is ever passed, rcondt is never
            # assigned and the next line raises UnboundLocalError; the single
            # call site below uses the default.
            tolt = torch.max(st) * rcondt
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            # nullspace.backward(torch.ones_like(nullspace),retain_graph=True)
            return nullspace

        # The probe now operates in the reduced (speaker-free) space.
        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)
    phone_labels = None
    if args.pathPhone is not None:
        # Phone-level probing (frame-aligned or CTC).
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'
        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones, args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones, args.get_encoded)
    else:
        # Speaker-level probing, optionally through the factorized criterion.
        label_key = 'speaker'
        print(f"Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))
    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)
    if args.debug:
        # Tiny splits for quick debugging runs.
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]
    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers),
                            nProcessLoader=args.n_process_loader)
    batch_size = args.batchSizeGPU * args.nGPU
    train_loader = db_train.getDataLoader(batch_size, "uniform", True, numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False, numWorkers=0)
    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        # Fine-tune the feature extractor together with the probe.
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        # Frozen features: only the probe receives gradients.
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False
    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)
    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")
    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)
    if args.centerpushFile:
        # Optionally load cluster centers (npy / txt / torch checkpoint) used to
        # push representations toward centers during training.
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile, map_location=torch.device('cpu'))['state_dict']['Ck']
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None
    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key,
        centerpushSettings=centerpushSettings)
# Optional debug subsampling: keep only a small random subset of sequences.
if args.debug:
    nsamples = 1000
    print(f"Debug mode activated, get only {nsamples} samples!")
    shuffle(seqNames)
    seqNames = seqNames[:nsamples]
# Distance estimation only needs a moderate random sample of the data.
if args.getDistanceEstimation:
    shuffle(seqNames)
    seqNames = seqNames[:5000]
print("")
print(f'Loading audio data at {args.pathDB}')
start_time = time.time()
# phoneLabelsDict is None here — no phone supervision for this loader.
dataset = AudioBatchData(args.pathDB,
                         args.sizeWindow,
                         seqNames,
                         None,
                         len(speakers),
                         nProcessLoader=args.n_process_loader,
                         MAX_SIZE_LOADED=args.max_size_loaded)
print(f"Dataset loaded in {time.time()-start_time} seconds !")
print("")
# Batch size scales with the number of visible GPUs.
nGPUs = torch.cuda.device_count()
batchSize = args.batchSizeGPU * nGPUs
trainLoader = dataset.getDataLoader(batchSize, "uniform", False, numWorkers=0)
print(f"Length of dataLoader: {len(trainLoader)}")
print("")
def main(argv):
    """Train a linear probe for phone or speaker separability on top of a
    pretrained CPC feature extractor.

    argv: raw command-line arguments forwarded to parse_args.
    """
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    # Enumerate every sequence in the database together with its speaker id.
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    # Pretrained feature extractor, replicated over the available GPUs.
    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))
    # Probe either the encoder output or the autoregressive context vector.
    dim_features = hidden_encoder if args.get_encoded else hidden_gar

    # Now the criterion: phone classification when labels are provided,
    # otherwise speaker classification.
    phone_labels = None
    if args.pathPhone is None:
        print(f"Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    else:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if args.CTC:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones,
                                             args.get_encoded)
        else:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones,
                                          args.get_encoded)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Train / validation splits and their loaders.
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)
    if args.debug:
        seq_train, seq_val = seq_train[:1000], seq_val[:100]
    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers))
    batch_size = args.batchSizeGPU * args.nGPU
    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Only the probe's parameters are optimized unless full fine-tuning is on.
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for p in model.parameters():
            p.requires_grad = False
    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory + a JSON dump of the launch configuration.
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")
    with open(f"{args.pathCheckpoint}_args.json", 'w') as fout:
        json.dump(vars(args), fout, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
def main(args):
    """Main CPC training entry point: resume from checkpoint if present, build
    train/val (and optional capture) datasets, the CPC model and criterion,
    the optimizer/scheduler, then run training, representation capture, and/or
    the supervised linear-separability metric depending on the flags.
    """
    # import ptvsd
    # ptvsd.enable_attach(('0.0.0.0', 7309))
    # print("Attach debugger now")
    # ptvsd.wait_for_attach()
    args = parseArgs(args)
    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    os.makedirs(args.pathCheckpoint, exist_ok=True)
    if not args.onlyCapture and not args.only_classif_metric:
        # NOTE(review): the file handle opened inline here is never closed
        # explicitly — relies on GC to flush/close.
        json.dump(vars(args), open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'), 'wt'))
    # Resume: if a checkpoint exists, restore its args (except the forbidden
    # machine-specific ones), its logs, and arrange to reload model/optimizer.
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint", "debug",
                                       "restart", "world_size", "n_nodes",
                                       "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True
    logs["logging_step"] = args.logging_step
    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)
    # NOTE(review): this guard mixes `not ... or ...`; seqTrain/seqVal are only
    # defined when it holds and are used further down — verify the intended
    # condition against the flags actually used together.
    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames
        if args.pathVal is None:
            # No explicit validation list: carve a 1% random slice off train.
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)
    if args.pathCaptureDS is not None:
        # Representation-capture configuration: which tensors to dump and how
        # often, plus the subset of sequences to capture on.
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr, args.captureCtxRepr,
                                     args.captureSpeakerAlign, args.capturePhoneAlign,
                                     args.capturePred, args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align',
                                     'phone_align', 'pred', 'cpcctc_align',
                                     'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        ###assert len(whatToSave) > 0
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq,
                                totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None
    if not args.onlyCapture:
        if args.debug:
            # Small tail slices of the splits for quick debugging runs.
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]
        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")
        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")
        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None
    if seqCapture is not None:
        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            # Cannot capture phone alignments without phone labels.
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None
        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")
        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None
    if args.load is not None:
        # Reload a CPC model (possibly with a nullspace wrapper) from checkpoint.
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None
        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        # they capture "best in this part of training" as "best" (apart from capturing current state)
        # so if best is in epoch 100 and training is paused and resumed from checkpoint
        # in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        # (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder
        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True
    else:
        # Fresh model: build encoder + autoregressive network from args.
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)
        cpcModel = model.CPCModel(encoderNet, arNet)
        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()
    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised
    # Nullspace wrapper keeps the real encoder one level down (.cpc).
    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING if isinstance(cpcModel, model.CPCModelNullspace) else cpcModel.gEncoder.DOWNSAMPLING
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling, len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling, len(speakers), nPhones)
    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])
    cpcCriterion.cuda()
    cpcModel.cuda()
    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())
    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)
    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])
    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)
    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        # Linear warm-up ramp, optionally combined with the step scheduler.
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                           lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                                                               n_epoch, epoch),
                                                           last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Fast-forward the scheduler to the resumed epoch.
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for i in range(len(logs["epoch"])):
            scheduler.step()
    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)
    cpcModel = torch.nn.DataParallel(cpcModel, device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion, device_ids=range(args.nGPU)).cuda()
    if args.supervised_classif_metric:
        # Periodic linear-separability evaluation: train small phone/speaker
        # probes on frozen features every few epochs.
        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU
        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps
        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)
            if not args.CTCphones:
                print(f"Running phone separability with aligned phones")
            else:
                print(f"Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                # Fresh phone probe + its own Adam optimizer (built per run so
                # each evaluation starts from scratch).
                if not args.CTCphones:
                    # print(f"Running phone separability with aligned phones")
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                        nPhonesInData,
                                                        args.phone_get_encoded,
                                                        nLayers=args.linsep_net_layers)
                else:
                    # print(f"Running phone separability with CTC loss")
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                           nPhonesInData,
                                                           args.phone_get_encoded,
                                                           nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion,
                                                        device_ids=range(args.nGPU))
                # Optimizer
                phone_g_params = list(phone_criterion.parameters())
                phone_optimizer = torch.optim.Adam(phone_g_params,
                                                   lr=args.linsep_lr,
                                                   betas=(args.linsep_beta1, args.linsep_beta2),
                                                   eps=args.linsep_epsilon)
                return phone_criterion, phone_optimizer
        if args.speaker_sep:
            print(f"Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                # Fresh speaker probe + optimizer, analogous to the phone one.
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features,
                                                        len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion,
                                                          device_ids=range(args.nGPU))
                speaker_g_params = list(speaker_criterion.parameters())
                speaker_optimizer = torch.optim.Adam(speaker_g_params,
                                                     lr=args.linsep_lr,
                                                     betas=(args.linsep_beta1, args.linsep_beta2),
                                                     eps=args.linsep_epsilon)
                return speaker_criterion, speaker_optimizer
        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                         phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                       phoneLabelsData, len(speakers))
        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size,
                                                            "uniform", True,
                                                            numWorkers=0)
        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size,
                                                        'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            # Train the probe(s) against the CPC state at epoch cpcStateEpoch,
            # logging (and optionally checkpointing) under per-epoch folders.
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            if not os.path.exists(log_path_for_epoch):
                os.makedirs(log_path_for_epoch)
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            if not os.path.exists(log_path_phoneme):
                os.makedirs(log_path_phoneme)
            if not os.path.exists(log_path_speaker):
                os.makedirs(log_path_speaker)
            if args.linsep_checkpoint_dir:
                checpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checpoint_path_for_epoch, "speaker/")
                if not os.path.exists(checkpoint_path_phoneme):
                    os.makedirs(checkpoint_path_phoneme)
                if not os.path.exists(checkpoint_path_speaker):
                    os.makedirs(checkpoint_path_speaker)
            # NOTE(review): if args.linsep_checkpoint_dir is falsy,
            # checkpoint_path_phoneme / checkpoint_path_speaker are unbound and
            # the calls below raise NameError — confirm the flag is mandatory.
            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer
            # Prefix the two log dicts so their keys cannot collide when merged.
            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)
    else:
        linsepClassificationTaskConfig = (None, None)
    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        # to use specific checkpoint provide full checkpoint file path
        # will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        # to use specific checkpoint provide full checkpoint file path
        # will use "last state" and not "best in internal CPC accuracy" anyway
        trainedEpoch = len(logs["epoch"]) - 1
        # runPhonemeClassificationTraining created above if args.supervised_classif_metric
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
def main(args):
    """CPC training entry point (no capture/linsep extras): resume from a
    checkpoint when available, build the datasets, model, criterion,
    optimizer and scheduler, then launch the training loop via run().
    """
    args = parseArgs(args)
    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    # Resume: restore checkpoint args (minus machine-specific ones), logs,
    # and schedule the model/criterion/optimizer state to be reloaded.
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint", "debug", "restart",
                            "world_size", "n_nodes", "node_id",
                            "n_gpu_per_node", "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True
    logs["logging_step"] = args.logging_step
    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)
    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames
    if args.pathVal is None:
        # No explicit validation list: carve a 1% random slice off train.
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)
    if args.debug:
        # Small tail slices of the splits for quick debugging runs.
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]
    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")
    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")
    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")
    if args.load is not None:
        # Reload model weights from the checkpoint selected above.
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)
    else:
        # Fresh model: build encoder + autoregressive network from args.
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)
        cpcModel = model.CPCModel(encoderNet, arNet)
    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)
    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])
    cpcCriterion.cuda()
    cpcModel.cuda()
    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())
    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params,
                                 lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)
    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])
    # Checkpoint
    if args.pathCheckpoint is not None:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        # Linear warm-up ramp, optionally combined with the step scheduler.
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Fast-forward the scheduler to the resumed epoch.
        for i in range(len(logs["epoch"])):
            scheduler.step()
    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()
    run(trainDataset,
        valDataset,
        batchSize,
        args.samplingType,
        cpcModel,
        cpcCriterion,
        args.nEpoch,
        args.pathCheckpoint,
        optimizer,
        scheduler,
        logs)
cnt_spk = np.zeros(len(speakers)) for seqn in seqNames: spk_id, fname = seqn cnt_spk[spk_id] += 1 if cnt_spk[spk_id] > 60: seqTest.append(seqn) else: seqTrain.append(seqn) # seq_train = filterSeqs(args.pathTrain, seqNames) # seq_val = filterSeqs(args.pathVal, seqNames) db_train = AudioBatchData( path='../../../CPC_librispeech/dataset/LibriSpeech/train-clean-100/', sizeWindow=20480, seqNames=seqTrain, phoneLabelsDict=None, nSpeakers=len(speakers), nProcessLoader=50, MAX_SIZE_LOADED=400000000) train_loader = db_train.getDataLoader( batchSize=64, type='uniform', #'sequential', randomOffset=False, numWorkers=0) db_test = AudioBatchData( path='../../../CPC_librispeech/dataset/LibriSpeech/train-clean-100/', sizeWindow=20480, seqNames=seqTest, phoneLabelsDict=None, nSpeakers=len(speakers),