def run(feature_maker,
        criterion,
        train_loader,
        val_loader,
        optimizer,
        logs,
        n_epochs,
        path_checkpoint,
        label_key="speaker",
        centerpushSettings=None):
    """Train and validate for the remaining epochs, logging and checkpointing.

    Training resumes at ``len(logs["epoch"])`` and runs up to ``n_epochs``.
    After every epoch the train/validation logs are printed and appended to
    ``logs``; a checkpoint and the logs file are written every
    ``logs["saveStep"]`` epochs (epoch 0 excluded) and on the last epoch.

    Args:
        feature_maker: Model passed to ``train_step`` / ``val_step``; its
            state dict (via ``fl.get_module``) is what gets checkpointed.
        criterion: Criterion passed to the step functions; its state dict is
            checkpointed as well.
        train_loader: Data loader handed to ``train_step``.
        val_loader: Data loader handed to ``val_step``.
        optimizer: Optimizer handed to ``train_step``; its state dict is
            saved in every checkpoint.
        logs: Mutable dict with at least the keys ``"epoch"`` (list) and
            ``"saveStep"``. It is updated in place.
        n_epochs: Exclusive upper bound of the epoch range.
        path_checkpoint: Path prefix for ``{prefix}_{epoch}.pt`` and
            ``{prefix}_logs.json``.
        label_key: Forwarded unchanged to the step functions.
        centerpushSettings: Forwarded unchanged to the step functions.
    """
    start_epoch = len(logs["epoch"])
    best_acc = -1
    # Bound up front so the checkpoint call below cannot hit an
    # UnboundLocalError when no epoch ever beats ``best_acc`` (e.g. the
    # validation accuracy is NaN, for which every comparison is False).
    best_state = None

    start_time = time.time()

    for epoch in range(start_epoch, n_epochs):

        logs_train = train_step(feature_maker, criterion, train_loader,
                                optimizer, label_key=label_key,
                                centerpushSettings=centerpushSettings)
        logs_val = val_step(feature_maker, criterion, val_loader,
                            label_key=label_key,
                            centerpushSettings=centerpushSettings)
        print('')
        print('_'*50)
        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')
        utils.show_logs("Training loss", logs_train)
        utils.show_logs("Validation loss", logs_val)
        print('_'*50)
        print('')

        # Snapshot the weights (deep copy, so later training steps cannot
        # modify the snapshot) whenever validation accuracy improves.
        if logs_val["locAcc_val"] > best_acc:
            best_state = deepcopy(fl.get_module(feature_maker).state_dict())
            best_acc = logs_val["locAcc_val"]

        logs["epoch"].append(epoch)
        for key, value in dict(logs_train, **logs_val).items():
            if key not in logs:
                # Pad a newly seen metric so its list lines up with epochs.
                logs[key] = [None for x in range(epoch)]
            if isinstance(value, np.ndarray):
                value = value.tolist()
            logs[key].append(value)

        if (epoch % logs["saveStep"] == 0 and epoch > 0) \
                or epoch == n_epochs - 1:
            model_state_dict = fl.get_module(feature_maker).state_dict()
            criterion_state_dict = fl.get_module(criterion).state_dict()

            fl.save_checkpoint(model_state_dict, criterion_state_dict,
                               optimizer.state_dict(), best_state,
                               f"{path_checkpoint}_{epoch}.pt")
            utils.save_logs(logs, f"{path_checkpoint}_logs.json")
# NOTE(review): the line below is a fragment of an epoch-loop body collapsed
# onto one physical line. Because it starts with '#', Python treats the whole
# line as a comment. The statements in it read `ep`, `logs_train`, `best_acc`,
# `start_time`, `feat_gen`, `clf`, `test_loader`, `optimizer`, `MAX_EPOCHS`
# and `SAVE_PATH`, none of which are defined within the fragment itself; the
# enclosing loop header is not part of this fragment. It mirrors the
# log/checkpoint sequence of `run`, but with a fixed save interval of 10
# epochs and "test" instead of "validation" logs.
#logs = {"epoch": [], "iter": [], "saveStep": -1} # saveStep=-1, save only best checkpoint! logs_test = test_step(feat_gen, clf, test_loader, optimizer, ep) print('') print('_' * 50) print(f'Ran {ep + 1} epochs ' f'in {time.time() - start_time:.2f} seconds') utils.show_logs("Training loss", logs_train) utils.show_logs("Test loss", logs_test) print('_' * 50) print('') if logs_test["locAcc_test"] > best_acc: best_state = deepcopy(fl.get_module(feat_gen).state_dict()) best_acc = logs_test["locAcc_test"] logs["epoch"].append(ep) for key, value in dict(logs_train, **logs_test).items(): if key not in logs: logs[key] = [None for x in range(ep)] if isinstance(value, np.ndarray): value = value.tolist() logs[key].append(value) if (ep % 10 == 0 and ep > 0) or ep == MAX_EPOCHS - 1: feat_gen_state_dict = fl.get_module(feat_gen).state_dict() clf_state_dict = fl.get_module(clf).state_dict() fl.save_checkpoint(feat_gen_state_dict, clf_state_dict, optimizer.state_dict(), best_state, f"{SAVE_PATH}_{ep}.pt") utils.save_logs(logs, f"{SAVE_PATH}_logs.json")
def run(trainDataset,
        valDataset,
        captureDatasetWithOptions,
        linsepClassificationTaskConfig,
        batchSize,
        samplingMode,
        cpcModel,
        cpcCriterion,
        nEpoch,
        pathCheckpoint,
        optimizer,
        scheduler,
        logs):
    """Run the train/validation loop with optional capture and linsep steps.

    Training resumes at ``len(logs["epoch"])`` and runs up to ``nEpoch``.
    Each epoch builds fresh data loaders, runs ``trainStep`` and ``valStep``,
    optionally runs ``captureStep`` and the linsep callback, appends the
    resulting metrics to ``logs`` and, when ``pathCheckpoint`` is given,
    periodically writes a checkpoint plus the logs file.

    Args:
        trainDataset: Object providing ``getDataLoader(...)`` for training.
        valDataset: Object providing ``getDataLoader(...)`` for validation.
        captureDatasetWithOptions: 3-tuple ``(captureDataset, captureOptions,
            captureStatsCollector)``. Dataset and options must both be
            ``None`` or both be set; ``captureOptions['eachEpochs']`` gives
            the capture period in epochs.
        linsepClassificationTaskConfig: 2-tuple ``(linsepEachEpochs,
            linsepFun)``. ``linsepFun(epoch, cpcModel, epoch)`` is called
            every ``linsepEachEpochs`` epochs (never at epoch 0) and must
            return a dict of metrics; ``linsepEachEpochs=None`` disables it.
        batchSize: Batch size forwarded to ``getDataLoader``.
        samplingMode: Sampling mode for the training loader.
        cpcModel: Model being trained.
        cpcCriterion: Criterion being trained.
        nEpoch: Exclusive upper bound of the epoch range.
        pathCheckpoint: Path prefix for checkpoints/logs, or ``None`` to
            disable saving.
        optimizer: Optimizer forwarded to ``trainStep`` and checkpointed.
        scheduler: Scheduler forwarded to ``trainStep``.
        logs: Mutable dict with at least ``"epoch"``, ``"logging_step"`` and
            ``"saveStep"``. It is updated in place.
    """
    startEpoch = len(logs["epoch"])
    print(f"Running {nEpoch} epochs, now at {startEpoch}")
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    captureDataset, captureOptions, captureStatsCollector = \
        captureDatasetWithOptions
    linsepEachEpochs, linsepFun = linsepClassificationTaskConfig
    assert (captureDataset is None and captureOptions is None) \
        or (captureDataset is not None and captureOptions is not None)
    if captureOptions is not None:
        captureEachEpochs = captureOptions['eachEpochs']

    print(f'DS sizes: train {str(len(trainDataset)) if trainDataset is not None else "-"}, '
          f'val {str(len(valDataset)) if valDataset is not None else "-"}, capture '
          f'{str(len(captureDataset)) if captureDataset is not None else "-"}')

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
                                                 True, numWorkers=0)
        valLoader = valDataset.getDataLoader(batchSize, 'sequential', False,
                                             numWorkers=0)

        if captureDataset is not None and epoch % captureEachEpochs == 0:
            captureLoader = captureDataset.getDataLoader(batchSize,
                                                         'sequential',
                                                         False,
                                                         numWorkers=0)

        print("Training dataset %d batches, Validation dataset %d batches, batch size %d" %
              (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        if captureDataset is not None and epoch % captureEachEpochs == 0:
            print(f"Capturing data for epoch {epoch}")
            captureStep(captureLoader, cpcModel, cpcCriterion,
                        captureOptions, captureStatsCollector, epoch)

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Remember the new best score as well as the weights; without
            # this update every epoch with accuracy > 0 would overwrite the
            # "best" snapshot with the latest one.
            bestAcc = currentAccuracy
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        locLogsLinsep = {}
        # this performs linsep task for the best CPC model up to date
        if linsepEachEpochs is not None and epoch != 0 \
                and epoch % linsepEachEpochs == 0:
            # capturing for current CPC state after this epoch, relying on
            # CPC internal accuracy is vague
            locLogsLinsep = linsepFun(epoch, cpcModel, epoch)

        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        for key, value in dict(locLogsTrain, **locLogsVal,
                               **locLogsLinsep).items():
            if key not in logs:
                logs[key] = [None for x in range(epoch)]
            if isinstance(value, np.ndarray):
                value = value.tolist()
            while len(logs[key]) < len(logs["epoch"]):
                logs[key].append(None)  # for not-every-epoch-logged things
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch-1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")
def run(trainDataset,
        valDataset,
        batchSize,
        samplingMode,
        cpcModel,
        cpcCriterion,
        nEpoch,
        pathCheckpoint,
        optimizer,
        scheduler,
        logs):
    """Run the train/validation loop, tracking the best model state.

    Training resumes at ``len(logs["epoch"])`` and runs up to ``nEpoch``.
    Each epoch builds fresh data loaders, runs ``trainStep`` and ``valStep``,
    appends the resulting metrics to ``logs`` and, when ``pathCheckpoint`` is
    given, writes a checkpoint plus the logs file every ``logs["saveStep"]``
    epochs and on the last epoch.

    Args:
        trainDataset: Object providing ``getDataLoader(...)`` for training.
        valDataset: Object providing ``getDataLoader(...)`` for validation.
        batchSize: Batch size forwarded to ``getDataLoader``.
        samplingMode: Sampling mode for the training loader.
        cpcModel: Model being trained.
        cpcCriterion: Criterion being trained.
        nEpoch: Exclusive upper bound of the epoch range.
        pathCheckpoint: Path prefix for checkpoints/logs, or ``None`` to
            disable saving.
        optimizer: Optimizer forwarded to ``trainStep`` and checkpointed.
        scheduler: Scheduler forwarded to ``trainStep``.
        logs: Mutable dict with at least ``"epoch"``, ``"logging_step"`` and
            ``"saveStep"``. It is updated in place.
    """
    # Local import keeps this block self-contained regardless of the
    # module's top-level imports.
    from copy import deepcopy

    print(f"Running {nEpoch} epochs")
    startEpoch = len(logs["epoch"])
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
                                                 True, numWorkers=0)

        valLoader = valDataset.getDataLoader(batchSize, 'sequential', False,
                                             numWorkers=0)

        print(
            "Training dataset %d batches, Validation dataset %d batches, batch size %d" %
            (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Record the new best score, otherwise any epoch with accuracy
            # > 0 would replace the snapshot. Deep-copy the state dict so
            # the snapshot is decoupled from the weights that keep training.
            bestAcc = currentAccuracy
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        for key, value in dict(locLogsTrain, **locLogsVal).items():
            if key not in logs:
                # Pad a newly seen metric so its list lines up with epochs.
                logs[key] = [None for x in range(epoch)]
            if isinstance(value, np.ndarray):
                value = value.tolist()
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch-1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")