Пример #1
0
def run(trainDataset,
        valDataset,
        captureDatasetWithOptions,
        linsepClassificationTaskConfig,
        batchSize,
        samplingMode,
        cpcModel,
        cpcCriterion,
        nEpoch,
        pathCheckpoint,
        optimizer,
        scheduler,
        logs):
    """Run the train/validate loop, with optional capture and linsep passes.

    Epochs already recorded in ``logs["epoch"]`` are skipped, so the loop
    covers ``range(len(logs["epoch"]), nEpoch)``.  Each epoch runs
    ``trainStep`` and ``valStep``, optionally ``captureStep`` and the linsep
    callback, merges the returned metrics into ``logs`` and, when due,
    writes a checkpoint plus the logs file.

    Args:
        trainDataset: Dataset exposing ``getDataLoader(batchSize,
            samplingMode, True, numWorkers=0)``.
        valDataset: Dataset exposing ``getDataLoader(batchSize,
            'sequential', False, numWorkers=0)``.
        captureDatasetWithOptions: 3-tuple ``(captureDataset, captureOptions,
            captureStatsCollector)``.  The dataset and the options must be
            both ``None`` or both set; ``captureOptions['eachEpochs']`` is
            the capture period in epochs.
        linsepClassificationTaskConfig: 2-tuple ``(linsepEachEpochs,
            linsepFun)``.  When ``linsepEachEpochs`` is not ``None``,
            ``linsepFun(epoch, cpcModel, epoch)`` is called on every
            non-zero epoch divisible by it and must return a dict of metrics.
        batchSize: Batch size handed to every data loader.
        samplingMode: Sampling mode handed to the training data loader.
        cpcModel: Model passed to the step functions and checkpointed.
        cpcCriterion: Criterion passed to the step functions and
            checkpointed.
        nEpoch: Exclusive upper bound of the epoch range.
        pathCheckpoint: Path prefix for ``{pathCheckpoint}_{epoch}.pt`` and
            ``{pathCheckpoint}_logs.json``; ``None`` disables saving.
        optimizer: Optimizer passed to ``trainStep``; its state is saved.
        scheduler: Scheduler passed to ``trainStep``.
        logs: Dict mutated in place.  Must hold ``"epoch"`` (list) and
            ``"logging_step"``; ``"saveStep"`` is read when
            ``pathCheckpoint`` is set.

    Returns:
        None.  Results are reported through ``logs`` and the saved files.
    """
    startEpoch = len(logs["epoch"])
    print(f"Running {nEpoch} epochs, now at {startEpoch}")
    # NOTE(review): the best accuracy restarts from 0 when resuming, because
    # previous best values are not restored from ``logs`` here.
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    captureDataset, captureOptions, captureStatsCollector = captureDatasetWithOptions
    linsepEachEpochs, linsepFun = linsepClassificationTaskConfig
    assert (captureDataset is None and captureOptions is None) \
        or (captureDataset is not None and captureOptions is not None)
    captureEachEpochs = None
    if captureOptions is not None:
        captureEachEpochs = captureOptions['eachEpochs']

    print(f'DS sizes: train {str(len(trainDataset)) if trainDataset is not None else "-"}, '
        f'val {str(len(valDataset)) if valDataset is not None else "-"}, capture '
        f'{str(len(captureDataset)) if captureDataset is not None else "-"}')

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        # Evaluated once per epoch; the short-circuit keeps the modulo from
        # running when capturing is disabled (captureEachEpochs is None).
        doCapture = captureDataset is not None \
            and epoch % captureEachEpochs == 0

        # Loaders are rebuilt at the start of every epoch.
        trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
                                                 True, numWorkers=0)

        valLoader = valDataset.getDataLoader(batchSize, 'sequential', False,
                                             numWorkers=0)

        if doCapture:
            captureLoader = captureDataset.getDataLoader(batchSize, 'sequential',
                                                         False, numWorkers=0)

        print("Training dataset %d batches, Validation dataset %d batches, batch size %d" %
            (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        if doCapture:
            print(f"Capturing data for epoch {epoch}")
            captureStep(captureLoader, cpcModel, cpcCriterion, captureOptions,
                        captureStatsCollector, epoch)

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Record the new best; without this every epoch with a positive
            # accuracy would overwrite the "best" snapshot.
            bestAcc = currentAccuracy
            # Deep copy so that later training steps cannot modify the
            # snapshot through tensors shared with the live model.
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        locLogsLinsep = {}
        # The linsep task runs on the current model state after this epoch,
        # not on the best snapshot: relying on the CPC internal accuracy to
        # pick the model is vague.
        if linsepEachEpochs is not None and epoch != 0 \
                and epoch % linsepEachEpochs == 0:
            locLogsLinsep = linsepFun(epoch, cpcModel, epoch)

        print(f'Ran {epoch + 1} epochs '
            f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        nLogged = len(logs["epoch"])
        for key, value in {**locLogsTrain, **locLogsVal, **locLogsLinsep}.items():
            if key not in logs:
                # First time this metric appears: pad the earlier epochs.
                logs[key] = [None] * epoch
            if isinstance(value, np.ndarray):
                value = value.tolist()
            # Metrics that are not produced every epoch (e.g. linsep) get
            # None for the epochs they skipped; a non-positive gap adds
            # nothing.
            logs[key].extend([None] * (nLogged - len(logs[key])))
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch - 1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")
Пример #2
0
def run(trainDataset, valDataset, batchSize, samplingMode, cpcModel,
        cpcCriterion, nEpoch, pathCheckpoint, optimizer, scheduler, logs):
    """Run the train/validate loop and periodically checkpoint the model.

    Epochs already recorded in ``logs["epoch"]`` are skipped, so the loop
    covers ``range(len(logs["epoch"]), nEpoch)``.  Each epoch runs
    ``trainStep`` and ``valStep``, merges the returned metrics into ``logs``
    and, when due, writes a checkpoint plus the logs file.

    Args:
        trainDataset: Dataset exposing ``getDataLoader(batchSize,
            samplingMode, True, numWorkers=0)``.
        valDataset: Dataset exposing ``getDataLoader(batchSize,
            'sequential', False, numWorkers=0)``.
        batchSize: Batch size handed to both data loaders.
        samplingMode: Sampling mode handed to the training data loader.
        cpcModel: Model passed to the step functions and checkpointed.
        cpcCriterion: Criterion passed to the step functions and
            checkpointed.
        nEpoch: Exclusive upper bound of the epoch range.
        pathCheckpoint: Path prefix for ``{pathCheckpoint}_{epoch}.pt`` and
            ``{pathCheckpoint}_logs.json``; ``None`` disables saving.
        optimizer: Optimizer passed to ``trainStep``; its state is saved.
        scheduler: Scheduler passed to ``trainStep``.
        logs: Dict mutated in place.  Must hold ``"epoch"`` (list) and
            ``"logging_step"``; ``"saveStep"`` is read when
            ``pathCheckpoint`` is set.

    Returns:
        None.  Results are reported through ``logs`` and the saved files.
    """
    # Imported locally so the block does not rely on a module-level import.
    from copy import deepcopy

    print(f"Running {nEpoch} epochs")
    startEpoch = len(logs["epoch"])
    # NOTE(review): the best accuracy restarts from 0 when resuming, because
    # previous best values are not restored from ``logs`` here.
    bestAcc = 0
    bestStateDict = None
    start_time = time.time()

    for epoch in range(startEpoch, nEpoch):

        print(f"Starting epoch {epoch}")
        utils.cpu_stats()

        # Loaders are rebuilt at the start of every epoch.
        trainLoader = trainDataset.getDataLoader(batchSize,
                                                 samplingMode,
                                                 True,
                                                 numWorkers=0)

        valLoader = valDataset.getDataLoader(batchSize,
                                             'sequential',
                                             False,
                                             numWorkers=0)

        print(
            "Training dataset %d batches, Validation dataset %d batches, batch size %d"
            % (len(trainLoader), len(valLoader), batchSize))

        locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion,
                                 optimizer, scheduler, logs["logging_step"])

        locLogsVal = valStep(valLoader, cpcModel, cpcCriterion)

        print(f'Ran {epoch + 1} epochs '
              f'in {time.time() - start_time:.2f} seconds')

        torch.cuda.empty_cache()

        currentAccuracy = float(locLogsVal["locAcc_val"].mean())
        if currentAccuracy > bestAcc:
            # Record the new best; without this every epoch with a positive
            # accuracy would overwrite the "best" snapshot.
            bestAcc = currentAccuracy
            # state_dict() hands back tensors shared with the live model, so
            # the snapshot must be deep-copied or later training steps would
            # silently overwrite it.
            bestStateDict = deepcopy(fl.get_module(cpcModel).state_dict())

        for key, value in {**locLogsTrain, **locLogsVal}.items():
            if key not in logs:
                # First time this metric appears: pad the earlier epochs.
                logs[key] = [None] * epoch
            if isinstance(value, np.ndarray):
                value = value.tolist()
            logs[key].append(value)

        logs["epoch"].append(epoch)

        if pathCheckpoint is not None \
                and (epoch % logs["saveStep"] == 0 or epoch == nEpoch - 1):

            modelStateDict = fl.get_module(cpcModel).state_dict()
            criterionStateDict = fl.get_module(cpcCriterion).state_dict()

            fl.save_checkpoint(modelStateDict, criterionStateDict,
                               optimizer.state_dict(), bestStateDict,
                               f"{pathCheckpoint}_{epoch}.pt")
            utils.save_logs(logs, pathCheckpoint + "_logs.json")