def main(): # setup experiment logging to comet.ml if expConfig.LOG_COMETML: hyper_params = {"experimentName": expConfig.EXPERIMENT_NAME, "epochs": expConfig.EPOCHS, "batchSize": expConfig.BATCH_SIZE, "channels": expConfig.CHANNELS, "virualBatchsize": expConfig.VIRTUAL_BATCHSIZE} expConfig.experiment.log_parameters(hyper_params) expConfig.experiment.add_tags([expConfig.EXPERIMENT_NAME, "ID{}".format(expConfig.id)]) if hasattr(expConfig, "EXPERIMENT_TAGS"): expConfig.experiment.add_tags(expConfig.EXPERIMENT_TAGS) print(bcolors.OKGREEN + "Logging to comet.ml" + bcolors.ENDC) else: print(bcolors.WARNING + "Not logging to comet.ml" + bcolors.ENDC) # log parameter count if expConfig.LOG_PARAMCOUNT: paramCount = sum(p.numel() for p in expConfig.net.parameters() if p.requires_grad) print("Parameters: {:,}".format(paramCount).replace(",", "'")) #load data randomCrop = expConfig.RANDOM_CROP if hasattr(expConfig, "RANDOM_CROP") else None trainset = bratsDataset.BratsDataset(systemsetup.BRATS_PATH, expConfig, mode="train", randomCrop=randomCrop) trainloader = torch.utils.data.DataLoader(trainset, batch_size=expConfig.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=expConfig.DATASET_WORKERS) valset = bratsDataset.BratsDataset(systemsetup.BRATS_PATH, expConfig, mode="validation") valloader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=False, pin_memory=True, num_workers=expConfig.DATASET_WORKERS) challengeValset = bratsDataset.BratsDataset(systemsetup.BRATS_VAL_PATH, expConfig, mode="validation", hasMasks=False, returnOffsets=True) challengeValloader = torch.utils.data.DataLoader(challengeValset, batch_size=1, shuffle=False, pin_memory=True, num_workers=expConfig.DATASET_WORKERS) seg = segmenter.Segmenter(expConfig, trainloader, valloader, challengeValloader) if hasattr(expConfig, "VALIDATE_ALL") and expConfig.VALIDATE_ALL: seg.validateAllCheckpoints() elif hasattr(expConfig, "PREDICT") and expConfig.PREDICT: seg.makePredictions() elif hasattr(expConfig, "VISUALIZE_PROB_MAP") and expConfig.VISUALIZE_PROB_MAP: seg.visualize_prob_maps() else: seg.train()
import os import torch import bratsDataset import segmenter import systemsetup import matplotlib.pyplot as plt from dataProcessing import utils import numpy as np import experiments.noNewNet as expConfig trainset = bratsDataset.BratsDataset(systemsetup.BRATS_PATH, expConfig, mode="train") valset = bratsDataset.BratsDataset(systemsetup.BRATS_PATH, expConfig, mode="validation") # sum_inputs = 0 # count = 0 # for i in range(len(trainset)): # inputs, pid, labels = trainset[i] # print("processing no.{} {}".format(i, pid)) # inputs = inputs.numpy() # sum_inputs += inputs # count += 1 # print(count) # # for i in range(len(valset)): # inputs, pid, labels = valset[i] # print("processing no.{} {}".format(i, pid))