def run(feature_maker,
        criterion,
        train_loader,
        val_loader,
        optimizer,
        logs,
        n_epochs,
        path_checkpoint,
        label_key="speaker",
        centerpushSettings=None):
    """Run the remaining train/validation epochs, logging and checkpointing.

    Training resumes at ``len(logs["epoch"])`` and stops before ``n_epochs``.
    After every epoch the merged train/validation metrics are appended to
    ``logs`` (mutated in place).  A checkpoint and the JSON logs are written
    when the epoch index is a positive multiple of ``logs["saveStep"]`` and
    always after the final epoch.

    Args:
        feature_maker: model handed to ``train_step``/``val_step``; its state
            dict is checkpointed.
        criterion: criterion module handed to both steps and checkpointed.
        train_loader: loader forwarded to ``train_step``.
        val_loader: loader forwarded to ``val_step``.
        optimizer: optimizer forwarded to ``train_step``; its state is saved.
        logs: dict containing at least the ``"epoch"`` list and ``"saveStep"``.
        n_epochs: epoch count to reach (exclusive upper bound of the loop).
        path_checkpoint: prefix used for ``<prefix>_<epoch>.pt`` and
            ``<prefix>_logs.json``.
        label_key: forwarded unchanged to both steps.
        centerpushSettings: forwarded unchanged to both steps.
    """
    separator = '_'*50
    first_epoch = len(logs["epoch"])
    last_epoch = n_epochs - 1
    top_val_acc = -1

    t_start = time.time()

    for epoch in range(first_epoch, n_epochs):

        logs_train = train_step(feature_maker, criterion, train_loader,
                                optimizer, label_key=label_key,
                                centerpushSettings=centerpushSettings)
        logs_val = val_step(feature_maker, criterion, val_loader,
                            label_key=label_key,
                            centerpushSettings=centerpushSettings)

        # Per-epoch report.
        print('')
        print(separator)
        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - t_start:.2f} seconds')
        utils.show_logs("Training loss", logs_train)
        utils.show_logs("Validation loss", logs_val)
        print(separator)
        print('')

        # Snapshot the weights whenever validation accuracy improves.
        val_acc = logs_val["locAcc_val"]
        if val_acc > top_val_acc:
            best_snapshot = deepcopy(fl.get_module(feature_maker).state_dict())
            top_val_acc = val_acc

        # Append this epoch's metrics; a metric first seen now is back-filled
        # with None for the epochs that came before it.
        logs["epoch"].append(epoch)
        epoch_metrics = dict(logs_train, **logs_val)
        for name in epoch_metrics:
            metric = epoch_metrics[name]
            history = logs.setdefault(name, [None] * epoch)
            if isinstance(metric, np.ndarray):
                metric = metric.tolist()
            history.append(metric)

        # Save periodically (never at epoch 0) and always on the last epoch.
        periodic_save = epoch % logs["saveStep"] == 0 and epoch > 0
        if not (periodic_save or epoch == last_epoch):
            continue

        model_state = fl.get_module(feature_maker).state_dict()
        criterion_state = fl.get_module(criterion).state_dict()
        fl.save_checkpoint(model_state, criterion_state,
                           optimizer.state_dict(), best_snapshot,
                           f"{path_checkpoint}_{epoch}.pt")
        utils.save_logs(logs, f"{path_checkpoint}_logs.json")
Example #2
0
    # NOTE(review): this is a fragment of a per-epoch body; the enclosing
    # function/loop header is not visible here.  `ep`, `start_time`,
    # `logs_train`, `best_acc`, `logs`, `MAX_EPOCHS` and `SAVE_PATH` are
    # expected to be provided by the surrounding code.
    #logs = {"epoch": [], "iter": [], "saveStep": -1} # saveStep=-1, save only best checkpoint!
    # Run the test step for epoch `ep` and collect its metrics.
    logs_test = test_step(feat_gen, clf, test_loader, optimizer, ep)
    print('')
    print('_' * 50)
    print(f'Ran {ep + 1} epochs ' f'in {time.time() - start_time:.2f} seconds')
    utils.show_logs("Training loss", logs_train)
    utils.show_logs("Test loss", logs_test)
    print('_' * 50)
    print('')

    # Keep a deep copy of the feature generator weights whenever the test
    # accuracy beats the best value seen so far.
    if logs_test["locAcc_test"] > best_acc:
        best_state = deepcopy(fl.get_module(feat_gen).state_dict())
        best_acc = logs_test["locAcc_test"]

    # Append this epoch's metrics to `logs`; a metric seen for the first time
    # is back-filled with None for the `ep` earlier epochs.
    logs["epoch"].append(ep)
    for key, value in dict(logs_train, **logs_test).items():
        if key not in logs:
            logs[key] = [None for x in range(ep)]
        if isinstance(value, np.ndarray):
            # Convert arrays to plain lists before storing them.
            value = value.tolist()
        logs[key].append(value)

    # Checkpoint every 10th epoch (except epoch 0) and on the last epoch.
    # NOTE(review): the interval is hard-coded to 10 here rather than read
    # from logs["saveStep"] — confirm this is intended.
    if (ep % 10 == 0 and ep > 0) or ep == MAX_EPOCHS - 1:
        feat_gen_state_dict = fl.get_module(feat_gen).state_dict()
        clf_state_dict = fl.get_module(clf).state_dict()

        fl.save_checkpoint(feat_gen_state_dict, clf_state_dict,
                           optimizer.state_dict(), best_state,
                           f"{SAVE_PATH}_{ep}.pt")
        utils.save_logs(logs, f"{SAVE_PATH}_logs.json")
Example #3
0
def run(trainDataset,
        valDataset,
        captureDatasetWithOptions,
        linsepClassificationTaskConfig,
        batchSize,
        samplingMode,
        cpcModel,
        cpcCriterion,
        nEpoch,
        pathCheckpoint,
        optimizer,
        scheduler,
        logs):
    """Train the CPC model, with optional data capture and linsep evaluation.

    Training resumes at ``len(logs["epoch"])`` and stops before ``nEpoch``.
    Each epoch builds fresh data loaders, runs ``trainStep`` and ``valStep``,
    optionally runs ``captureStep`` and the linear-separability callback,
    appends the collected metrics to ``logs`` (mutated in place) and, when
    ``pathCheckpoint`` is given, periodically writes a checkpoint and the logs.

    Args:
        trainDataset: object exposing ``getDataLoader`` for training batches.
        valDataset: object exposing ``getDataLoader`` for validation batches.
        captureDatasetWithOptions: triple ``(captureDataset, captureOptions,
            captureStatsCollector)``.  The dataset and options must both be
            ``None`` or both be set; ``captureOptions['eachEpochs']`` is the
            capture period in epochs.
        linsepClassificationTaskConfig: pair ``(linsepEachEpochs, linsepFun)``.
            When ``linsepEachEpochs`` is not ``None``, ``linsepFun`` is called
            every ``linsepEachEpochs`` epochs (never at epoch 0).
        batchSize: batch size passed to every ``getDataLoader`` call.
        samplingMode: sampling mode for the training loader.
        cpcModel: model being trained.
        cpcCriterion: criterion module used by the train/val/capture steps.
        nEpoch: epoch count to reach (exclusive upper bound of the loop).
        pathCheckpoint: checkpoint path prefix, or ``None`` to disable saving.
        optimizer: optimizer passed to ``trainStep``; its state is saved.
        scheduler: scheduler passed to ``trainStep``.
        logs: dict containing at least ``"epoch"``, ``"logging_step"`` and
            ``"saveStep"``.
    """
    startEpoch = len(logs["epoch"])
    print(f"Running {nEpoch} epochs, now at {startEpoch}")
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    captureDataset, captureOptions, captureStatsCollector = captureDatasetWithOptions
    linsepEachEpochs, linsepFun = linsepClassificationTaskConfig
    # Capture dataset and capture options must be given together or not at all.
    assert (captureDataset is None) == (captureOptions is None)
    if captureOptions is not None:
        captureEachEpochs = captureOptions['eachEpochs']

    print(f'DS sizes: train {str(len(trainDataset)) if trainDataset is not None else "-"}, '
          f'val {str(len(valDataset)) if valDataset is not None else "-"}, capture '
          f'{str(len(captureDataset)) if captureDataset is not None else "-"}')

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
                                                 True, numWorkers=0)

        valLoader = valDataset.getDataLoader(batchSize, 'sequential', False,
                                             numWorkers=0)

        # Decide once per epoch whether capture data is collected.
        captureThisEpoch = (captureDataset is not None
                            and epoch % captureEachEpochs == 0)
        if captureThisEpoch:
            captureLoader = captureDataset.getDataLoader(batchSize, 'sequential',
                                                         False, numWorkers=0)

        print("Training dataset %d batches, Validation dataset %d batches, batch size %d" %
              (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        if captureThisEpoch:
            print(f"Capturing data for epoch {epoch}")
            captureStep(captureLoader, cpcModel, cpcCriterion, captureOptions,
                        captureStatsCollector, epoch)

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Record the new best accuracy; without this every epoch with a
            # positive accuracy would overwrite the "best" snapshot.
            bestAcc = currentAccuracy
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        locLogsLinsep = {}
        # The linsep task is run on the CURRENT model state after this epoch
        # (not on the best snapshot): the internal CPC accuracy is only a
        # vague proxy for downstream quality.
        if linsepEachEpochs is not None and epoch != 0 and epoch % linsepEachEpochs == 0:
            locLogsLinsep = linsepFun(epoch, cpcModel, epoch)

        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        for key, value in dict(locLogsTrain, **locLogsVal, **locLogsLinsep).items():
            if key not in logs:
                logs[key] = [None for x in range(epoch)]
            if isinstance(value, np.ndarray):
                value = value.tolist()
            # Metrics that are not logged every epoch (e.g. linsep) are padded
            # with None so each series stays aligned with logs["epoch"].
            while len(logs[key]) < len(logs["epoch"]):
                logs[key].append(None)
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch-1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")
def trainLinsepClassification(
        feature_maker,
        criterion,  # combined with classification model before
        train_loader,
        val_loader,
        optimizer,
        path_logs,
        logs_save_step,
        path_best_checkpoint,
        n_epochs,
        cpc_epoch,
        label_key="speaker",
        centerpushSettings=None):
    """Train a linear-separability classifier on top of a frozen feature maker.

    The feature maker is switched to eval mode and its ``optimize`` attribute
    is set to ``False`` for the duration of the call; ``optimize`` is restored
    afterwards (also when a step raises).  The module is left in eval mode.
    Training always starts at epoch 0 with a fresh log dict.

    Args:
        feature_maker: feature extractor passed to ``train_step``/``val_step``.
        criterion: classification criterion (already combined with the
            classification model) that is being trained.
        train_loader: loader forwarded to ``train_step``.
        val_loader: loader forwarded to ``val_step``.
        optimizer: optimizer forwarded to ``train_step``.
        path_logs: prefix of the ``<prefix>_logs.json`` log file.
        logs_save_step: period, in epochs, at which logs are written.
        path_best_checkpoint: directory for the best checkpoint; a falsy value
            disables saving it.
        n_epochs: number of epochs to train.
        cpc_epoch: CPC epoch number, used only in the checkpoint file name.
        label_key: forwarded to both steps; also used in messages/file names.
        centerpushSettings: forwarded unchanged to both steps.

    Returns:
        Dict with ``num_epoch_trained``, ``best_val_acc`` and
        ``best_train_acc`` (both accuracies stay ``-1`` if no epoch ran).
    """
    wasOptimizeCPC = getattr(feature_maker, 'optimize', None)
    feature_maker.eval()
    feature_maker.optimize = False

    best_train_acc = -1
    best_acc = -1
    best_epoch = -1
    # Initialised so that a run without any improving epoch (e.g.
    # n_epochs == 0) cannot hit unbound names when saving below.
    best_state_cpc = None
    best_state_classif_crit = None
    optimizer_state_best_ep = None
    logs = {"epoch": [], "iter": [], "saveStep": logs_save_step}

    start_time = time.time()

    try:
        for epoch in range(n_epochs):

            logs_train = train_step(feature_maker, criterion, train_loader,
                                    optimizer, label_key,
                                    centerpushSettings=centerpushSettings)
            logs_val = val_step(feature_maker, criterion, val_loader, label_key,
                                centerpushSettings=centerpushSettings)
            print('')
            print('_'*50)
            print(f'Ran {epoch + 1} {label_key} classification epochs '
                  f'in {time.time() - start_time:.2f} seconds')
            utils.show_logs("Training loss", logs_train)
            utils.show_logs("Validation loss", logs_val)
            print('_'*50)
            print('')

            if logs_val["locAcc_val"] > best_acc:
                best_state_cpc = deepcopy(fl.get_module(feature_maker).state_dict())
                best_state_classif_crit = deepcopy(fl.get_module(criterion).state_dict())
                # state_dict() hands back references to the optimizer's live
                # state, so it must be copied to really capture this epoch.
                optimizer_state_best_ep = deepcopy(optimizer.state_dict())
                best_epoch = epoch
                best_acc = logs_val["locAcc_val"]

            if logs_train["locAcc_train"] > best_train_acc:
                best_train_acc = logs_train["locAcc_train"]

            logs["epoch"].append(epoch)
            for key, value in dict(logs_train, **logs_val).items():
                if key not in logs:
                    logs[key] = [None for x in range(epoch)]
                if isinstance(value, np.ndarray):
                    value = value.tolist()
                logs[key].append(value)

            # Only the logs are written periodically; per-epoch model
            # checkpoints are intentionally not saved here.
            if (epoch % logs["saveStep"] == 0 and epoch > 0) or epoch == n_epochs - 1:
                utils.save_logs(logs, f"{path_logs}_logs.json")

        # Nothing to save when no epoch produced a best state.
        if path_best_checkpoint and best_state_cpc is not None:
            save_linsep_best_checkpoint(
                best_state_cpc, best_state_classif_crit,
                optimizer_state_best_ep,
                os.path.join(path_best_checkpoint,
                             f"{label_key}_classif_best-epoch{best_epoch}-cpc_epoch{cpc_epoch}.pt"))
    finally:
        # Always hand the feature maker back with its original setting.
        feature_maker.optimize = wasOptimizeCPC

    return {'num_epoch_trained': n_epochs,
            'best_val_acc': best_acc,
            'best_train_acc': best_train_acc
            }
Example #5
0
def run(trainDataset, valDataset, batchSize, samplingMode, cpcModel,
        cpcCriterion, nEpoch, pathCheckpoint, optimizer, scheduler, logs):
    """Train the CPC model for the remaining epochs, logging and checkpointing.

    Training resumes at ``len(logs["epoch"])`` and stops before ``nEpoch``.
    Each epoch builds fresh data loaders, runs ``trainStep`` and ``valStep``,
    appends the metrics to ``logs`` (mutated in place) and, when
    ``pathCheckpoint`` is given, writes a checkpoint plus the logs every
    ``logs["saveStep"]`` epochs and after the final epoch.

    Args:
        trainDataset: object exposing ``getDataLoader`` for training batches.
        valDataset: object exposing ``getDataLoader`` for validation batches.
        batchSize: batch size passed to every ``getDataLoader`` call.
        samplingMode: sampling mode for the training loader.
        cpcModel: model being trained.
        cpcCriterion: criterion module used by the train/val steps.
        nEpoch: epoch count to reach (exclusive upper bound of the loop).
        pathCheckpoint: checkpoint path prefix, or ``None`` to disable saving.
        optimizer: optimizer passed to ``trainStep``; its state is saved.
        scheduler: scheduler passed to ``trainStep``.
        logs: dict containing at least ``"epoch"``, ``"logging_step"`` and
            ``"saveStep"``.
    """
    # Imported locally so the snapshot below does not depend on a
    # module-level import being present.
    from copy import deepcopy

    print(f"Running {nEpoch} epochs")
    startEpoch = len(logs["epoch"])
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        trainLoader = trainDataset.getDataLoader(batchSize,
                                                 samplingMode,
                                                 True,
                                                 numWorkers=0)

        valLoader = valDataset.getDataLoader(batchSize,
                                             'sequential',
                                             False,
                                             numWorkers=0)

        print(
            "Training dataset %d batches, Validation dataset %d batches, batch size %d"
            % (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Track the best accuracy so later, worse epochs do not replace
            # the snapshot, and deep-copy the weights: state_dict() returns
            # references that would otherwise follow further training.
            bestAcc = currentAccuracy
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        # A metric seen for the first time is back-filled with None for the
        # epochs that came before it.
        for key, value in dict(locLogsTrain, **locLogsVal).items():
            if key not in logs:
                logs[key] = [None for x in range(epoch)]
            if isinstance(value, np.ndarray):
                value = value.tolist()
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch-1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")