def main(args):
    """Entry point: parse arguments, then train, capture and/or evaluate a CPC model.

    Steps, in order:

    * parse ``args`` with ``parseArgs`` and, unless ``args.restart`` is set,
      resume from the checkpoint found in ``args.pathCheckpoint`` (its saved
      arguments are merged into ``args`` via ``fl.loadArgs``);
    * build the train / validation datasets (training and classification-only
      runs) and the optional capture dataset (``args.pathCaptureDS``);
    * load (``args.load``) or freshly construct the CPC model and criterion,
      an Adam optimizer and the optional StepLR / ramp LR schedulers;
    * when ``args.supervised_classif_metric`` is set, prepare the
      linear-separability probes (phone and/or speaker);
    * run full training via ``run`` unless ``args.onlyCapture`` or
      ``args.only_classif_metric`` is set; run ``onlyCapture`` if
      ``args.onlyCapture``; run the probes once if ``args.only_classif_metric``.

    Args:
        args: raw command-line argument list, forwarded to ``parseArgs``.

    Raises:
        ValueError: if ``only_classif_metric`` is requested while
            ``supervised_classif_metric`` is disabled (no probe task would exist).
    """
    # import ptvsd
    # ptvsd.enable_attach(('0.0.0.0', 7309))
    # print("Attach debugger now")
    # ptvsd.wait_for_attach()
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False

    if args.pathCheckpoint is not None:
        # Guarded: os.makedirs(None) raises TypeError, and pathCheckpoint is
        # treated as optional everywhere else in this function.
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        if not args.onlyCapture and not args.only_classif_metric:
            # Snapshot the command-line arguments (capture-only and
            # classification-only runs skip this). The context manager makes
            # sure the file is flushed and closed before anything below
            # (e.g. fl.getCheckpointData) inspects the checkpoint directory.
            with open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'),
                      'wt') as argsFile:
                json.dump(vars(args), argsFile)

    # Resume from an existing checkpoint unless a restart was requested.
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            # forbiddenAttr presumably lists attributes that are NOT taken
            # from the checkpoint arguments - see fl.loadArgs.
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    # Fail fast: the classification-only mode needs the probe task that is
    # built only when supervised_classif_metric is set. Otherwise the error
    # would surface as an UnboundLocalError after all data and models loaded.
    if args.only_classif_metric and not args.supervised_classif_metric:
        raise ValueError("only_classif_metric requires supervised_classif_metric "
                         "to be enabled")

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    # seqTrain / seqVal are needed by training and by the linear-separability
    # probes; a capture-only run never uses them.
    needsTrainValSplit = not args.onlyCapture or args.only_classif_metric

    if needsTrainValSplit:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            # No explicit validation list: shuffle and hold out the last 1%.
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align',
                                     'phone_align', 'pred', 'cpcctc_align',
                                     'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        ###assert len(whatToSave) > 0
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq,
                                totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:
        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None
        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        #     they capture "best in this part of training" as "best" (apart from capturing current state)
        #     so if best is in epoch 100 and training is paused and resumed from checkpoint
        #     in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        #     (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True
    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # The nullspace wrapper keeps the CPC model under `.cpc`.
    downsampling = (cpcModel.cpc.gEncoder.DOWNSAMPLING
                    if isinstance(cpcModel, model.CPCModelNullspace)
                    else cpcModel.gEncoder.DOWNSAMPLING)

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        # From here on pathCheckpoint is a file-name prefix, not a directory.
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Replay one scheduler step per epoch already recorded in the logs.
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for _ in range(len(logs["epoch"])):
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    # The probes are only used by training (`run`) and by the classification-only
    # mode; capture-only runs skip them (seqTrain/seqVal do not exist there).
    if args.supervised_classif_metric and needsTrainValSplit:

        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)

            if not args.CTCphones:
                print(f"Running phone separability with aligned phones")
            else:
                print(f"Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                """Build a fresh phone probe (aligned or CTC) and its Adam optimizer."""
                if not args.CTCphones:
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                        nPhonesInData, args.phone_get_encoded,
                                                        nLayers=args.linsep_net_layers)
                else:
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                           nPhonesInData, args.phone_get_encoded,
                                                           nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(phone_g_params, lr=args.linsep_lr,
                                                   betas=(args.linsep_beta1, args.linsep_beta2),
                                                   eps=args.linsep_epsilon)

                return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print(f"Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                """Build a fresh speaker probe and its Adam optimizer."""
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features, len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion, device_ids=range(args.nGPU))

                speaker_g_params = list(speaker_criterion.parameters())

                speaker_optimizer = torch.optim.Adam(speaker_g_params, lr=args.linsep_lr,
                                                     betas=(args.linsep_beta1, args.linsep_beta2),
                                                     eps=args.linsep_epsilon)

                return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                         phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                       phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
                                                            numWorkers=0)

        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            """Train the configured probes on top of ``cpcMdl`` and return their logs.

            Logs are written under ``<linsep_logs_dir>/<numOfEpoch>/{phoneme,speaker}/``;
            returned keys are prefixed with ``phone_`` / ``speaker_``.
            """
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            os.makedirs(log_path_phoneme, exist_ok=True)
            os.makedirs(log_path_speaker, exist_ok=True)

            # Bound up front so the calls below do not raise UnboundLocalError
            # when no linsep checkpoint directory is configured.
            # NOTE(review): assumes linsep.trainLinsepClassification treats a
            # None checkpoint path as "do not checkpoint" - verify.
            checkpoint_path_phoneme = None
            checkpoint_path_speaker = None
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch, "speaker/")
                os.makedirs(checkpoint_path_phoneme, exist_ok=True)
                os.makedirs(checkpoint_path_speaker, exist_ok=True)

            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                # Drop references so the probe can be freed before the next one is built.
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)

    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        trainedEpoch = len(logs["epoch"]) - 1
        # runLinsepClassificationTraining is defined above: only_classif_metric
        # implies supervised_classif_metric (checked early) and needsTrainValSplit.
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
def main(args):
    """Parse arguments and train a CPC model (plain training variant).

    Resumes from the checkpoint in ``args.pathCheckpoint`` unless
    ``args.restart`` is set. Builds the train/validation ``AudioBatchData``
    datasets and loads (``args.load``) or constructs the CPC model and
    criterion. Sets up Adam plus the optional StepLR / ramp LR schedulers,
    wraps model and criterion in ``DataParallel`` and hands everything to
    ``run``.

    NOTE(review): this ``def main`` appears after the earlier ``main``
    definition in this file, so when the module executes it rebinds the
    module-level name ``main``. The earlier (capture / linear-separability)
    version is therefore unreachable through ``main``. Confirm which one is
    intended and remove or rename the other.

    NOTE(review): ``run`` is called here with 11 positional arguments, while
    the earlier ``main`` passes 13 (it adds the capture tuple and the linsep
    task config). At most one of the two calls can match ``run``'s signature.

    Args:
        args: raw command-line argument list, forwarded to ``parseArgs``.
    """
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    # Resume from an existing checkpoint unless a restart was requested.
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            # forbiddenAttr presumably lists attributes NOT taken from the
            # checkpoint arguments - see fl.loadArgs.
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint",
                            "debug", "restart", "world_size",
                            "n_nodes", "node_id", "n_gpu_per_node",
                            "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        # No explicit validation list: shuffle and hold out the last 1%.
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    # Debug mode keeps only the last 1000 train / 100 val sequences.
    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    # When resuming, restore the criterion weights from the checkpoint file.
    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        state_dict = torch.load(args.load[0], 'cpu')
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None:
        # NOTE(review): os.mkdir does not create missing parent directories.
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        # From here on pathCheckpoint is a file-name prefix, not a directory.
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    # Replay one scheduler step per epoch already recorded in the logs.
    if scheduler is not None:
        for i in range(len(logs["epoch"])):
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset,
        valDataset,
        batchSize,
        args.samplingType,
        cpcModel,
        cpcCriterion,
        args.nEpoch,
        args.pathCheckpoint,
        optimizer,
        scheduler,
        logs)