예제 #1
0
def main():
    """Run a single BraTS experiment described by the ``expConfig`` module.

    The function
      * logs hyper-parameters and tags to comet.ml when ``LOG_COMETML`` is set,
      * prints the trainable parameter count when ``LOG_PARAMCOUNT`` is set,
      * builds the training, validation and challenge-validation loaders, and
      * hands them to a ``segmenter.Segmenter`` and runs the first enabled mode
        among ``VALIDATE_ALL``, ``PREDICT`` and ``VISUALIZE_PROB_MAP``,
        falling back to ``train()`` when none of them is set.

    The attributes ``EXPERIMENT_TAGS``, ``RANDOM_CROP`` and the three mode
    flags are optional and may be absent from ``expConfig``.
    """
    # --- experiment tracking ---------------------------------------------
    if not expConfig.LOG_COMETML:
        print(bcolors.WARNING + "Not logging to comet.ml" + bcolors.ENDC)
    else:
        hyper_params = {
            "experimentName": expConfig.EXPERIMENT_NAME,
            "epochs": expConfig.EPOCHS,
            "batchSize": expConfig.BATCH_SIZE,
            "channels": expConfig.CHANNELS,
            # NOTE(review): "virual" looks like a typo for "virtual"; renaming
            # the key would change the logged parameter name, so it is kept.
            "virualBatchsize": expConfig.VIRTUAL_BATCHSIZE,
        }
        experiment = expConfig.experiment
        experiment.log_parameters(hyper_params)
        run_tags = [expConfig.EXPERIMENT_NAME, "ID{}".format(expConfig.id)]
        experiment.add_tags(run_tags)
        if hasattr(expConfig, "EXPERIMENT_TAGS"):
            experiment.add_tags(expConfig.EXPERIMENT_TAGS)
        print(bcolors.OKGREEN + "Logging to comet.ml" + bcolors.ENDC)

    # --- model size ------------------------------------------------------
    if expConfig.LOG_PARAMCOUNT:
        param_count = 0
        for param in expConfig.net.parameters():
            if param.requires_grad:
                param_count += param.numel()
        # Thousands separators are rendered as apostrophes, e.g. 1'234'567.
        print("Parameters: {:,}".format(param_count).replace(",", "'"))

    # --- data ------------------------------------------------------------
    def make_loader(dataset, batch_size, shuffle):
        # Every loader pins host memory and uses the configured worker count.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=expConfig.DATASET_WORKERS,
        )

    random_crop = getattr(expConfig, "RANDOM_CROP", None)
    train_set = bratsDataset.BratsDataset(
        systemsetup.BRATS_PATH, expConfig, mode="train", randomCrop=random_crop
    )
    train_loader = make_loader(train_set, expConfig.BATCH_SIZE, True)

    val_set = bratsDataset.BratsDataset(
        systemsetup.BRATS_PATH, expConfig, mode="validation"
    )
    val_loader = make_loader(val_set, 1, False)

    # The challenge validation split is built without masks and with offsets.
    challenge_set = bratsDataset.BratsDataset(
        systemsetup.BRATS_VAL_PATH,
        expConfig,
        mode="validation",
        hasMasks=False,
        returnOffsets=True,
    )
    challenge_loader = make_loader(challenge_set, 1, False)

    # --- run mode --------------------------------------------------------
    seg = segmenter.Segmenter(expConfig, train_loader, val_loader, challenge_loader)
    if getattr(expConfig, "VALIDATE_ALL", False):
        seg.validateAllCheckpoints()
    elif getattr(expConfig, "PREDICT", False):
        seg.makePredictions()
    elif getattr(expConfig, "VISUALIZE_PROB_MAP", False):
        seg.visualize_prob_maps()
    else:
        seg.train()
예제 #2
0
import os
import torch
import bratsDataset
import segmenter
import systemsetup
import matplotlib.pyplot as plt
from dataProcessing import utils
import numpy as np

import experiments.noNewNet as expConfig

# Build the training and validation splits from the same BraTS data root,
# using the configuration in experiments.noNewNet.
trainset = bratsDataset.BratsDataset(
    systemsetup.BRATS_PATH, expConfig, mode="train"
)
valset = bratsDataset.BratsDataset(
    systemsetup.BRATS_PATH, expConfig, mode="validation"
)

# sum_inputs = 0
# count = 0
# for i in range(len(trainset)):
#     inputs, pid, labels = trainset[i]
#     print("processing no.{} {}".format(i, pid))
#     inputs = inputs.numpy()
#     sum_inputs += inputs
#     count += 1
# print(count)
#
# for i in range(len(valset)):
#     inputs, pid, labels = valset[i]
#     print("processing no.{} {}".format(i, pid))