Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=10)
    parser.add_argument('--patch', type=int, default=114)
    parser.add_argument('--train-root', type=str)
    parser.add_argument('--dice', action='store_true')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i',
                        '--inference',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--weight-decay',
                        '--wd',
                        default=1e-8,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-8)')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--lr', default=1e-1, type=float)
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    args = parser.parse_args()
    print(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = args.save or 'work/vnet.base.{}'.format(datestr())
    weight_decay = args.weight_decay
    if not args.train_root:  # argparse default is None, so test truthiness
        print("error: please provide the data path via --train-root")
        exit(1)

    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    batch_size = args.ngpu * args.batchSz
    model = vnet.VNet(classes=23, batch_size=batch_size)
    gpu_ids = range(args.ngpu)
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)
    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if args.cuda:
        model = model.cuda()
        print("cuda done")

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    print("loading training set")
    trainSet = CtDataset(args.train_root,
                         classes=23,
                         mode="train",
                         patch_size=[args.patch, args.patch, args.patch],
                         new_space=[3, 3, 3],
                         winw=350,
                         winl=50)
    trainLoader = data.DataLoader(trainSet,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  **kwargs)
    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=weight_decay)
    best_loss = 1
    for epoch in range(1, args.nEpochs + 1):
        print("Epoch {}:".format(epoch))
        train_loss = train(args, epoch, model, trainLoader, optimizer,
                           batch_size)
        is_best = False
        if train_loss < best_loss:
            is_best = True
            best_loss = train_loss
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': train_loss
            }, is_best, args.save, "vnet_coarse")
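
Note: these examples all call a save_checkpoint helper that is defined elsewhere. A minimal sketch of what it plausibly looks like, modeled on the PyTorch ImageNet example these scripts borrow from (the .pth.tar naming and the copy-on-best step are assumptions, not confirmed by the snippets):

import os
import shutil

import torch


def save_checkpoint(state, is_best, path, prefix):
    # Persist the latest state, and keep a separate copy of the best one so far.
    filename = os.path.join(path, '{}_checkpoint.pth.tar'.format(prefix))
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(
            filename,
            os.path.join(path, '{}_model_best.pth.tar'.format(prefix)))

Examples #3, #5 and #8 call it with path=/prefix= keywords and no is_best argument, so their variant presumably drops the copy step.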
Example #2
def main():
    # region params
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=1)
    parser.add_argument('--dice', action='store_true', default=True)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--nEpochs', type=int, default=10)
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i',
                        '--inference',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--weight-decay',
                        '--wd',
                        default=1e-8,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-8)')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    args = parser.parse_args()
    # endregion

    best_prec1 = 100.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = args.save or 'work/vnet.base.{}'.format(datestr())
    nll = True  # whether to use log-softmax + NLL loss
    if args.dice:
        nll = False
    weight_decay = args.weight_decay
    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=True, nll=nll)
    batch_size = args.ngpu * args.batchSz
    # multi-GPU execution is not supported here
    # model = nn.parallel.DataParallel(model)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    if nll:
        train = train_nll
        test = test_nll
        class_balance = True
    else:
        train = train_dice
        test = test_dice
        class_balance = False

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if args.cuda:
        model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    # LUNA16 dataset isotropically scaled to 2.5mm^3
    # and then truncated or zero-padded to 160x128x160
    normMu = [-642.794]
    normSigma = [459.512]
    normTransform = transforms.Normalize(normMu, normSigma)

    trainTransform = transforms.Compose([transforms.ToTensor(), normTransform])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])
    if ct_targets == nodule_masks:
        masks = lung_masks
    else:
        masks = None


    # inference is skipped by default; it only runs when --inference is given
    if args.inference != '':
        if not args.resume:
            print("args.resume must be set to do inference")
            exit(1)
        kwargs = {'num_workers': 1} if args.cuda else {}
        src = args.inference
        dst = args.save
        inference_batch_size = args.ngpu
        root = os.path.dirname(src)
        images = os.path.basename(src)
        dataset = dset.LUNA16(root=root,
                              images=images,
                              transform=testTransform,
                              split=target_split,
                              mode="infer")
        loader = DataLoader(dataset,
                            batch_size=inference_batch_size,
                            shuffle=False,
                            collate_fn=noop,
                            **kwargs)
        inference(args, loader, model, trainTransform)
        return

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    print("loading training set")
    trainSet = dset.LUNA16(root='luna16',
                           images=ct_images,
                           targets=ct_targets,
                           mode="train",
                           transform=trainTransform,
                           class_balance=class_balance,
                           split=target_split,
                           seed=args.seed,
                           masks=masks)
    trainLoader = DataLoader(trainSet,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    print("loading test set")
    testSet = dset.LUNA16(root='luna16',
                          images=ct_images,
                          targets=ct_targets,
                          mode="test",
                          transform=testTransform,
                          seed=args.seed,
                          masks=masks,
                          split=target_split)
    testLoader = DataLoader(testSet,
                            batch_size=batch_size,
                            shuffle=False,
                            **kwargs)

    # class-weight balancing disabled; use fixed equal weights instead
    # target_mean = trainSet.target_mean()
    # bg_weight = target_mean / (1. + target_mean)
    # fg_weight = 1. - bg_weight
    # print(bg_weight)
    # class_weights = torch.FloatTensor([bg_weight, fg_weight])

    class_weights = torch.FloatTensor([0.5, 0.5])
    if args.cuda:
        class_weights = class_weights.cuda()

    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)
    else:  # fall back to adam
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    err_best = 100.
    # best_prec1 = 100.

    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)

        train(args, epoch, model, trainLoader, optimizer, trainF,
              class_weights)

        err = test(args, epoch, model, testLoader, optimizer, testF,
                   class_weights)

        is_best = False
        if err < best_prec1:
            is_best = True
            best_prec1 = err
        # save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict(),  'best_prec1': best_prec1}, is_best, args.save, "vnet")
        save_checkpoint(
            {
                'epoch': epoch,
                'model': model,
                'best_prec1': best_prec1
            }, is_best, args.save, "vnet")
        os.system('./plot.py {} {} &'.format(len(trainLoader), args.save))

    trainF.close()
    testF.close()
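
adjust_opt is invoked once per epoch above but never defined. In mattmacy's vnet.pytorch, which this code closely follows, it is a step-decay schedule applied only to SGD; a sketch under that assumption (the epoch boundaries are illustrative):

def adjust_opt(optAlg, optimizer, epoch):
    # Step decay for SGD; adam/rmsprop keep the defaults they were built with.
    if optAlg == 'sgd':
        if epoch < 150:
            lr = 1e-1
        elif epoch == 150:
            lr = 1e-2
        elif epoch == 225:
            lr = 1e-3
        else:
            return
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr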
Example #3
def main(params, args):
    best_prec1 = 100.  # accuracy? by Chao
    epochs = args.nEpochs
    nr_iter = args.numIterations  # params['ModelParams']['numIterations']
    batch_size = args.batchsize  # params['ModelParams']['batchsize']
    resultDir = 'results/vnet.base.{}.{}'.format(params['ModelParams']['task'],
                                                 datestr())

    weight_decay = args.weight_decay
    setproctitle.setproctitle(resultDir)
    if os.path.exists(resultDir):
        shutil.rmtree(resultDir)
    os.makedirs(resultDir, exist_ok=True)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=False, nll=False)
    gpu_ids = args.gpu_ids
    # torch.cuda.set_device(gpu_ids) # why do I have to add this line? It seems the below line is useless to apply GPU devices. By Chao.
    # model = nn.parallel.DataParallel(model, device_ids=[gpu_ids])
    model = nn.parallel.DataParallel(model)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    train = train_dice
    test = test_dice

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if args.cuda:
        model = model.cuda()

    # transform
    trainTransform = transforms.Compose([transforms.ToTensor()])
    testTransform = transforms.Compose([transforms.ToTensor()])

    if args.opt == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.baseLR,
            momentum=args.momentum,
            weight_decay=weight_decay)  # params['ModelParams']['baseLR']
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    # pdb.set_trace()
    DataManagerParams = {
        'dstRes': np.asarray(eval(args.dstRes), dtype=float),
        'VolSize': np.asarray(eval(args.VolSize), dtype=int),
        'normDir': params['DataManagerParams']['normDir']
    }

    if params['ModelParams'][
            'dirTestImage']:  # if exists, means test files are given.
        print("\nloading training set")
        dataManagerTrain = DM.DataManager(
            params['ModelParams']['dirTrainImage'],
            params['ModelParams']['dirTrainLabel'],
            params['ModelParams']['dirResult'], DataManagerParams)
        dataManagerTrain.loadTrainingData()  # required
        train_images = dataManagerTrain.getNumpyImages()
        train_labels = dataManagerTrain.getNumpyGT()

        print("\nloading test set")
        dataManagerTest = DM.DataManager(params['ModelParams']['dirTestImage'],
                                         params['ModelParams']['dirTestLabel'],
                                         params['ModelParams']['dirResult'],
                                         DataManagerParams)
        dataManagerTest.loadTestingData()  # required
        test_images = dataManagerTest.getNumpyImages()
        test_labels = dataManagerTest.getNumpyGT()

        testSet = customDataset.customDataset(mode='test',
                                              images=test_images,
                                              GT=test_labels,
                                              transform=testTransform)
        testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

    elif args.testProp:  # if 'dirTestImage' is not given but 'testProp' is given, means only one data set is given. need to perform train_test_split.
        print('\n loading dataset, will split into train and test')
        dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                     params['ModelParams']['dirTrainLabel'],
                                     params['ModelParams']['dirResult'],
                                     DataManagerParams)
        dataManager.loadTrainingData()  # required
        numpyImages = dataManager.getNumpyImages()
        numpyGT = dataManager.getNumpyGT()
        # pdb.set_trace()

        train_images, train_labels, test_images, test_labels = train_test_split(
            numpyImages, numpyGT, args.testProp)
        testSet = customDataset.customDataset(mode='test',
                                              images=test_images,
                                              GT=test_labels,
                                              transform=testTransform)
        testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

    else:  # if both 'dirTestImage' and 'testProp' are not given, means the only one dataset provided is used as train set.
        print('\n loading only train dataset')
        dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                     params['ModelParams']['dirTrainLabel'],
                                     params['ModelParams']['dirResult'],
                                     DataManagerParams)
        dataManager.loadTrainingData()  # required
        train_images = dataManager.getNumpyImages()
        train_labels = dataManager.getNumpyGT()

        test_images = None
        test_labels = None
        testSet = None
        testLoader = None

    if params['ModelParams']['dirTestImage']:
        dataManager_toTestFunc = dataManagerTest
    else:
        dataManager_toTestFunc = dataManager

    ### For train_images and train_labels, starting data augmentation and loading augmented data with multiprocessing
    dataQueue = Queue(30)  # max 30 images in queue?
    dataPreparation = [None] * params['ModelParams']['nProc']

    # processes creation
    for proc in range(0, params['ModelParams']['nProc']):
        dataPreparation[proc] = Process(target=dataAugmentation,
                                        args=(params, args, dataQueue,
                                              train_images, train_labels))
        dataPreparation[proc].daemon = True
        dataPreparation[proc].start()

    batchData = np.zeros(
        (batch_size, DataManagerParams['VolSize'][0],
         DataManagerParams['VolSize'][1], DataManagerParams['VolSize'][2]),
        dtype=float)
    batchLabel = np.zeros(
        (batch_size, DataManagerParams['VolSize'][0],
         DataManagerParams['VolSize'][1], DataManagerParams['VolSize'][2]),
        dtype=float)

    trainF = open(os.path.join(resultDir, 'train.csv'), 'w')
    testF = open(os.path.join(resultDir, 'test.csv'), 'w')

    for epoch in range(1, epochs + 1):
        dataQueue_tmp = dataQueue  # not working from epoch = 2 and so on. why??? By Chao.
        diceOvBatch = 0
        err = 0
        for iteration in range(1, nr_iter + 1):
            # adjust_opt(args.opt, optimizer, iteration+)
            if args.opt == 'sgd':
                if np.mod(iteration, args.stepsize) == 0:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= args.gamma

            for i in range(batch_size):
                [defImg, defLab] = dataQueue_tmp.get()

                batchData[i, :, :, :] = defImg.astype(dtype=np.float32)
                batchLabel[i, :, :, :] = (defLab > 0.5).astype(
                    dtype=np.float32)

            trainSet = customDataset.customDataset(mode='train',
                                                   images=batchData,
                                                   GT=batchLabel,
                                                   transform=trainTransform)
            trainLoader = DataLoader(trainSet,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     **kwargs)

            diceOvBatch_tmp, err_tmp = train(args, epoch, iteration, model,
                                             trainLoader, optimizer, trainF)

            if args.xLabel == 'Iteration':
                trainF.write('{},{},{}\n'.format(iteration, diceOvBatch_tmp,
                                                 err_tmp))
                trainF.flush()
            elif args.xLabel == 'Epoch':
                diceOvBatch += diceOvBatch_tmp
                err += err_tmp
        if args.xLabel == 'Epoch':
            trainF.write('{},{},{}\n'.format(epoch, diceOvBatch / nr_iter,
                                             err / nr_iter))
            trainF.flush()

        if np.mod(epoch,
                  epochs) == 0:  # default to set last epoch to save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1
                },
                path=resultDir,
                prefix="vnet_epoch{}".format(epoch))
        if epoch == epochs and testLoader:
            test(dataManager_toTestFunc, args, epoch, model, testLoader, testF,
                 resultDir)  # by Chao.

    os.system('./plot.py {} {} &'.format(args.xLabel, resultDir))

    trainF.close()
    testF.close()

    # inference, i.e. output predicted mask for test data in .mhd
    if params['ModelParams']['dirInferImage'] != '':
        print("loading inference data")
        dataManagerInfer = DM.DataManager(
            params['ModelParams']['dirInferImage'], None,
            params['ModelParams']['dirResult'], DataManagerParams)
        dataManagerInfer.loadInferData(
        )  # required.  Create .loadInferData??? by Chao.
        numpyImages = dataManagerInfer.getNumpyImages()

        inferSet = customDataset.customDataset(mode='infer',
                                               images=numpyImages,
                                               GT=None,
                                               transform=testTransform)
        inferLoader = DataLoader(inferSet,
                                 batch_size=1,
                                 shuffle=True,
                                 **kwargs)
        inference(dataManagerInfer, args, inferLoader, model, resultDir)
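
train_test_split here is not sklearn's: it takes the numpy image and label collections returned by the data manager plus a proportion, and returns four values. A minimal sketch consistent with that call, assuming (hypothetically) that the data managers key volumes by filename:

import random


def train_test_split(numpyImages, numpyGT, testProp, seed=1):
    # Hold out a proportion of the keyed volumes as the test set.
    keys = sorted(numpyImages.keys())
    random.Random(seed).shuffle(keys)
    n_test = max(1, int(len(keys) * testProp))
    test_keys, train_keys = keys[:n_test], keys[n_test:]
    return ({k: numpyImages[k] for k in train_keys},
            {k: numpyGT[k] for k in train_keys},
            {k: numpyImages[k] for k in test_keys},
            {k: numpyGT[k] for k in test_keys})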
Example #4
def train():
    # Load data and prepare training samples
    numpyImages, numpyGT = load_data()
    dataQueue = Queue(30)  # max 30 images in queue
    dataPreparation = [None] * cfg.nProc

    # process creation (data-preparation workers)
    for proc in range(cfg.nProc):
        dataPreparation[proc] = Process(target=prepare_data_thread,
                                        args=(dataQueue, numpyImages, numpyGT))
        dataPreparation[proc].daemon = True
        dataPreparation[proc].start()

    def data_gen():
        for _ in range(cfg.numIterations * cfg.batchSize):
            defImg, defLab, _ = dataQueue.get()
            yield defImg, defLab

    print("Load data.")
    # tensorflow data loader
    h, w, d = params["VolSize"]
    dataset = tf.data.Dataset.from_generator(
        data_gen, (tf.float32, tf.int32),
        (tf.TensorShape([h, w, d, 1]), tf.TensorShape([h, w, d])))
    dataset = dataset.batch(batch_size=cfg.batchSize)

    print("Build model.")
    # build model
    model = vnet.VNet([h, w, d, 1], cfg.batchSize, cfg.ncls)
    learning_rate = cfg.baseLR
    learning_rate = K.optimizers.schedules.ExponentialDecay(
        learning_rate, cfg.decay_steps, cfg.decay_rate, True)
    optim = K.optimizers.SGD(learning_rate, momentum=0.99)
    criterion = K.losses.SparseCategoricalCrossentropy(from_logits=True)

    @tf.function
    def train_step(x, y):
        # Forward
        with tf.GradientTape() as tape:
            prediction = model(x)
            losses = criterion(y, prediction)
        # Backward
        with tf.name_scope("Gradients"):
            gradients = tape.gradient(losses, model.trainable_variables)
        optim.apply_gradients(zip(gradients, model.trainable_variables))
        return losses, prediction

    # File writer
    writer, logdir = utils.summary_writer(cfg)
    # Trace graph
    tf.summary.trace_on(graph=True)
    train_step(tf.zeros([1, h, w, d, 1]),
               tf.zeros([1, h, w, d]))  # dry run for tracing graph (step=1)
    tf.summary.trace_export("OpGraph", 0)

    print("Start training.")
    save_path = logdir / "snapshots"
    total_loss = 0
    dice = None
    for trImg, trLab in dataset:
        loss, pred = train_step(trImg, trLab)
        step = optim.iterations.numpy()  # steps start from 2 because of the dry run above
        loss_val = loss.numpy()

        # Loss moving average
        total_loss = loss_val if step < 5 else \
            cfg.moving_average * total_loss + (1 - cfg.moving_average) * loss_val

        # Logging
        if (step < 500 and step % 10 == 0) or step % cfg.log_interval == 0:
            dice = utils.compute_dice(trLab, pred)
            print(f"Step: {step}, Loss: {loss_val:.4f}, Dice: {dice:.4f}, "
                  f"LR: {learning_rate(step).numpy():.2E}")

            # Summary scalars and images
            tf.summary.scalar("loss", total_loss, step=step)
            tf.summary.scalar("dice", dice, step=step)
            tf.summary.image("trImg", trImg[..., d // 2, :], step=step)
            tf.summary.image("pred", pred[..., d // 2, :], step=step)

        # Take snapshots
        if step == 2 or step % cfg.snap_shot_interval == 0:
            filepath = utils.snapshot(model, save_path, step)
            print(f"Model weights saved (Path: {filepath}).")

    # Ending
    filepath = utils.snapshot(model, save_path, optim.iterations.numpy())
    print(f"Model weights saved ({filepath}).\nTraining ended.")
    writer.close()
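
utils.compute_dice is not shown. For a binary foreground/background task it is presumably the standard Dice coefficient over the argmax prediction; a sketch under that assumption:

import tensorflow as tf


def compute_dice(labels, logits, eps=1e-6):
    # Dice = 2|P & G| / (|P| + |G|), on the hard argmax prediction.
    pred = tf.cast(tf.argmax(logits, axis=-1), tf.float32)
    truth = tf.cast(labels, tf.float32)
    intersection = tf.reduce_sum(pred * truth)
    return float((2. * intersection + eps) /
                 (tf.reduce_sum(pred) + tf.reduce_sum(truth) + eps))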
Example #5
def main(params, args):
    ###############      NOTE: COMMON PART STARTS        ##################
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L139
    # best_prec1 sort of can be seen here in above link as best_acc1.
    # This is used to keep track of best_acc1 achieved yet in the checkpoints
    best_prec1 = 100.  # accuracy? by Chao
    epochs = args.nEpochs
    nr_iter = args.numIterations  # params['ModelParams']['numIterations']
    batch_size = args.batchsize  # params['ModelParams']['batchsize']
    task = params['ModelParams']['task']

    # for every run, a folder is created and this is how it gets its name
    resultDir = 'results/vnet.base.{}.{}'.format(
        task, datestr())


    # https://becominghuman.ai/this-thing-called-weight-decay-a7cd4bcfccab
    weight_decay = args.weight_decay

    # https://pypi.org/project/setproctitle/
    # The setproctitle module allows a process to change its title (as displayed by system tools such as ps and top).
    # set title of the current process
    setproctitle.setproctitle(resultDir)

    # https://docs.python.org/3/library/shutil.html#shutil.rmtree
    if os.path.exists(resultDir):
        shutil.rmtree(resultDir)
    os.makedirs(resultDir, exist_ok=True)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # https://discuss.pytorch.org/t/what-is-manual-seed/5939/4
    # You just need to call torch.manual_seed(seed), and it will set the seed of the random number generator to a fixed value,
    # so that when you call for example torch.rand(2), the results will be reproducible.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=False, nll=False)
    
    gpu_ids = args.gpu_ids
    # torch.cuda.set_device(gpu_ids) # why do I have to add this line? It seems the below line is useless to apply GPU devices. By Chao.
    # model = nn.parallel.DataParallel(model, device_ids=[gpu_ids])
    model = nn.parallel.DataParallel(model)

    global DM
    # NOTE: Change the data manager according to task at hand
    if task == 'nci-isbi-2013':
        DM = DCM  # DICOM-based data manager for this task; otherwise keep the default DM
    ###############      NOTE: COMMON PART ENDS       ##################

    if not args.testonly:
        # either resume model training - in which case, pass the path to checkpoint
        # or declare initial weights
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                
                # A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor.
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        else:

            # https://stackoverflow.com/questions/49433936/how-to-initialize-weights-in-pytorch
            # https://discuss.pytorch.org/t/parameters-initialisation/20001
            model.apply(weights_init)

        train = train_dice
        test = test_dice

        print('  + Number of params: {}'.format(
            sum([p.data.nelement() for p in model.parameters()])))
        if args.cuda:
            model = model.cuda()

        # https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.functional.to_tensor
        # Convert a PIL Image or numpy.ndarray to tensor
        trainTransform = transforms.Compose([
            transforms.ToTensor()
        ])
        testTransform = transforms.Compose([
            transforms.ToTensor()
        ])

        # setting optimiser from argument
        if args.opt == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=args.baseLR,
                                momentum=args.momentum, weight_decay=weight_decay)  # params['ModelParams']['baseLR']
        elif args.opt == 'adam':
            optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
        elif args.opt == 'rmsprop':
            optimizer = optim.RMSprop(
                model.parameters(), weight_decay=weight_decay)

        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

        # pdb.set_trace()
        DataManagerParams = {
            'dstRes': np.asarray(eval(args.dstRes), dtype=float),
            'VolSize': np.asarray(eval(args.VolSize), dtype=int),
            'normDir': params['DataManagerParams']['normDir']
        }

        # if exists, means test files are given.
        if params['ModelParams']['dirTestImage']:
            print("\nloading training set")
            dataManagerTrain = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                            params['ModelParams']['dirTrainLabel'],
                                            params['ModelParams']['dirResult'],
                                            DataManagerParams)
            dataManagerTrain.loadTrainingData()  # required
            train_images = dataManagerTrain.getNumpyImages()
            train_labels = dataManagerTrain.getNumpyGT()

            print("\nloading test set")
            dataManagerTest = DM.DataManager(params['ModelParams']['dirTestImage'], params['ModelParams']['dirTestLabel'],
                                            params['ModelParams']['dirResult'],
                                            DataManagerParams)
            dataManagerTest.loadTestingData()  # required
            test_images = dataManagerTest.getNumpyImages()
            test_labels = dataManagerTest.getNumpyGT()

            
            testSet = customDataset.customDataset(
                mode='test',
                images=test_images,
                GT=test_labels,
                task=task,
                # testTransform just converts the ndarray to a tensor
                # REVIEW: shouldn't we be setting both transform and GT_transform?
                transform=testTransform)

            testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

        elif args.testProp:  # if 'dirTestImage' is not given but 'testProp' is given, means only one data set is given. need to perform train_test_split.
            print('\n loading dataset, will split into train and test')
            dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                        params['ModelParams']['dirTrainLabel'],
                                        params['ModelParams']['dirResult'],
                                        DataManagerParams)
            dataManager.loadTrainingData()  # required
            numpyImages = dataManager.getNumpyImages()
            numpyGT = dataManager.getNumpyGT()
            # pdb.set_trace()

            train_images, train_labels, test_images, test_labels = train_test_split(
                numpyImages, numpyGT, args.testProp)
            testSet = customDataset.customDataset(
                mode='test', images=test_images, task=task, GT=test_labels, transform=testTransform)
            testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

        else:  # if both 'dirTestImage' and 'testProp' are not given, means the only one dataset provided is used as train set.
            print('\n loading only train dataset')
            dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                        params['ModelParams']['dirTrainLabel'],
                                        params['ModelParams']['dirResult'],
                                        DataManagerParams)
            dataManager.loadTrainingData()  # required
            train_images = dataManager.getNumpyImages()
            train_labels = dataManager.getNumpyGT()

            test_images = None
            test_labels = None
            testSet = None
            testLoader = None

        if params['ModelParams']['dirTestImage']:
            dataManager_toTestFunc = dataManagerTest
        else:
            dataManager_toTestFunc = dataManager

        ### For train_images and train_labels, starting data augmentation and loading augmented data with multiprocessing
        dataQueue = Queue(30)  # max 30 images in queue?
        dataPreparation = [None] * params['ModelParams']['nProc']

        # processes creation
        for proc in range(0, params['ModelParams']['nProc']):
            # the dataAugmentation processes put the augmented training images in the dataQueue
            dataPreparation[proc] = Process(target=dataAugmentation,
                                            args=(params, args, dataQueue, train_images, train_labels))
            dataPreparation[proc].daemon = True
            dataPreparation[proc].start()

        batchData = np.zeros((batch_size, DataManagerParams['VolSize'][0],
                            DataManagerParams['VolSize'][1],
                            DataManagerParams['VolSize'][2]), dtype=float)
        batchLabel = np.zeros((batch_size, DataManagerParams['VolSize'][0],
                            DataManagerParams['VolSize'][1],
                            DataManagerParams['VolSize'][2]), dtype=float)

        trainF = open(os.path.join(resultDir, 'train.csv'), 'w')
        testF = open(os.path.join(resultDir, 'test.csv'), 'w')

        print(torch.cuda.is_available())

        for epoch in range(1, epochs+1):
            # not working from epoch = 2 and so on. why??? By Chao.
            dataQueue_tmp = dataQueue
            diceOvBatch = 0
            err = 0
            for iteration in range(1, nr_iter + 1):
                # adjust_opt(args.opt, optimizer, iteration+)
                if args.opt == 'sgd':
                    if np.mod(iteration, args.stepsize) == 0:
                        for param_group in optimizer.param_groups:
                            param_group['lr'] *= args.gamma

                for i in range(batch_size):
                    [defImg, defLab] = dataQueue_tmp.get()

                    batchData[i, :, :, :] = defImg.astype(dtype=np.float32)
                    batchLabel[i, :, :, :] = (
                        defLab > 0.5).astype(dtype=np.float32)

                trainSet = customDataset.customDataset(mode='train', images=batchData, GT=batchLabel,
                                                    task=task,
                                                    transform=trainTransform)
                trainLoader = DataLoader(
                    trainSet, batch_size=batch_size, shuffle=True, **kwargs)

                diceOvBatch_tmp, err_tmp = train(
                    args, epoch, iteration, model, trainLoader, optimizer, trainF)

                if args.xLabel == 'Iteration':
                    trainF.write('{},{},{}\n'.format(
                        iteration, diceOvBatch_tmp, err_tmp))
                    trainF.flush()
                elif args.xLabel == 'Epoch':
                    diceOvBatch += diceOvBatch_tmp
                    err += err_tmp
            if args.xLabel == 'Epoch':
                trainF.write('{},{},{}\n'.format(
                    epoch, diceOvBatch/nr_iter, err/nr_iter))
                trainF.flush()

            if np.mod(epoch, epochs) == 0:  # default to set last epoch to save checkpoint
                save_checkpoint({'epoch': epoch,
                                'state_dict': model.state_dict(),
                                'best_prec1': best_prec1}, path=resultDir, prefix="vnet_epoch{}".format(epoch))
            if epoch == epochs and testLoader:
                # by Chao.
                test(dataManager_toTestFunc, args, epoch,
                    model, testLoader, testF, resultDir)

        os.system('./plot.py {} {} &'.format(args.xLabel, resultDir))

        trainF.close()
        testF.close()

        # inference, i.e. output predicted mask for test data
        if params['ModelParams']['dirInferImage'] != '':
            print("loading inference data")
            dataManagerInfer = DM.DataManager(params['ModelParams']['dirInferImage'], None,
                                            params['ModelParams']['dirResult'],
                                            DataManagerParams)
            # required.  Create .loadInferData??? by Chao.
            dataManagerInfer.loadInferData()
            numpyImages = dataManagerInfer.getNumpyImages()

            inferSet = customDataset.customDataset(
                mode='infer', images=numpyImages, task=task, GT=None, transform=testTransform)

            inferLoader = DataLoader(inferSet, batch_size=1, shuffle=True, **kwargs)

            inference(dataManagerInfer, args, inferLoader, model, resultDir)
    else:
        print(f"Only running testing on the test dataset of '{params['ModelParams']['task']}' using model saved at '{args.testonly}'")

        # BUG: Initially this will work only for the case of using trained model of promise12 on nci, because nci has a clearer test data with labels,
        # the accuracy of which would be easier to compare
        
        # REVIEW: All experimental below
        assert not args.resume, "Cannot resume training when only testing. Remove one of the --resume or --testonly flags"
        
        model_path = args.testonly
        
        # load model
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            
            # A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor.
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                .format(model_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(model_path))
            exit()

        test = test_dice

        print('  + Number of params: {}'.format(
            sum([p.data.nelement() for p in model.parameters()])))
        if args.cuda:
            model = model.cuda()

        testTransform = transforms.Compose([
            transforms.ToTensor()
        ])

        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

        # pdb.set_trace()
        DataManagerParams = {
            'dstRes': np.asarray(eval(args.dstRes), dtype=float),
            'VolSize': np.asarray(eval(args.VolSize), dtype=int),
            'normDir': params['DataManagerParams']['normDir']
        }

        # if exists, means test files are given.
        if params['ModelParams']['dirTestImage']:
            print("\nloading test set")
            dataManagerTest = DM.DataManager(params['ModelParams']['dirTestImage'], params['ModelParams']['dirTestLabel'],
                                            params['ModelParams']['dirResult'],
                                            DataManagerParams)
            dataManagerTest.loadTestingData()  # required
            test_images = dataManagerTest.getNumpyImages()
            test_labels = dataManagerTest.getNumpyGT()

            
            testSet = customDataset.customDataset(
                mode='test',
                images=test_images,
                GT=test_labels,
                task=task,
                # testTransform just converts the ndarray to a tensor
                # REVIEW: shouldn't we be setting both transform and GT_transform?
                transform=testTransform)

            testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

        elif args.testProp:  # if 'dirTestImage' is not given but 'testProp' is given, means only one data set is given. need to perform train_test_split.
            print('\n loading dataset, will split into train and test')
            dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
                                        params['ModelParams']['dirTrainLabel'],
                                        params['ModelParams']['dirResult'],
                                        DataManagerParams)
            dataManager.loadTrainingData()  # required
            numpyImages = dataManager.getNumpyImages()
            numpyGT = dataManager.getNumpyGT()
            # pdb.set_trace()

            train_images, train_labels, test_images, test_labels = train_test_split(
                numpyImages, numpyGT, args.testProp)
            testSet = customDataset.customDataset(
                mode='test', images=test_images, task=task, GT=test_labels, transform=testTransform)
            testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)

        else:  # if both 'dirTestImage' and 'testProp' are not given, means the only one dataset provided is used as train set.
            assert False, "There needs to be a test set specified for testonly mode"

        if params['ModelParams']['dirTestImage']:
            dataManager_toTestFunc = dataManagerTest
        else:
            dataManager_toTestFunc = dataManager

        batchData = np.zeros((batch_size, DataManagerParams['VolSize'][0],
                            DataManagerParams['VolSize'][1],
                            DataManagerParams['VolSize'][2]), dtype=float)
        batchLabel = np.zeros((batch_size, DataManagerParams['VolSize'][0],
                            DataManagerParams['VolSize'][1],
                            DataManagerParams['VolSize'][2]), dtype=float)

        testF = open(os.path.join(resultDir, 'test.csv'), 'w')

        print(torch.cuda.is_available())

        epoch = 1
        test(dataManager_toTestFunc, args, epoch,
            model, testLoader, testF, resultDir)

        testF.close()
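
weights_init is applied in every example when training starts from scratch. In the original vnet.pytorch it Kaiming-initializes the 3D convolutions; a sketch along those lines (the exact set of initialized layers is an assumption):

import torch.nn as nn


def weights_init(m):
    # Kaiming init for 3D convolutions; other layers keep PyTorch defaults.
    if isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()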
Example #6
    # sub_volume_im = torch.reshape(sub_volume_im, (1, 1, 96, 96, 96))
    # sub_volume_im = sub_volume_im.type(torch.FloatTensor)
    #
    # sub_volume_label = trainTransform(one_hot)
    # sub_volume_label = torch.reshape(sub_volume_label, (1, 23, -1))
    # sub_volume_label = sub_volume_label.type(torch.FloatTensor)
    # print(sub_volume_im.shape)
    # print(sub_volume_label.shape)
    # exit()
    #
    # model = vnet.VNet(classes=23, batch_size=1).cpu()
    # output = model(sub_volume_im)
    # print(output.shape)



if __name__ == '__main__':
    trainSet = CtDataset("./data/test", 23, "train", [96, 96, 96], [3, 3, 3], 350, 50)
    trainLoader = data.DataLoader(trainSet, batch_size=2, shuffle=False)
    dataiter = iter(trainLoader)
    img, mask = next(dataiter)  # .next() is gone from DataLoader iterators; use the builtin next()
    print(mask.shape)
    exit()
    _, axs = plt.subplots(2, 1)
    axs[0].imshow(img[0, 0, 0, :, :], cmap='gray')
    axs[1].imshow(mask[0, 20, 0, :, :], cmap='gray')
    plt.show()
    model = vnet.VNet(classes=23, batch_size=1).cpu()
    output = model(img)
    print(output.shape)
    # test_load()
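
datestr() names the work/results directories throughout these examples. A minimal sketch, assuming all that is needed is a sortable timestamp:

import time


def datestr():
    # e.g. '20240131_0945' - run directories sort by start time.
    now = time.gmtime()
    return '{}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon,
                                            now.tm_mday, now.tm_hour,
                                            now.tm_min)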
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=10)
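    # NOTE: argparse's type=bool is a footgun: bool() of any non-empty string
    # (including 'False') is True, so --nll cannot actually be disabled here.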
    parser.add_argument('--nll', type=bool, default=True)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--weight-decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    args = parser.parse_args()
    best_prec1 = 100.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = args.save or 'work/vnet.base.{}'.format(datestr())
    nll = args.nll
    weight_decay = args.weight_decay
    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=False, nll=nll)
    batch_size = args.ngpu * args.batchSz
    if args.ngpu > 1:
        gpu_ids = range(args.ngpu)
        model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    if nll:
        train = train_nll
        test = test_nll
        class_balance = True
    else:
        train = train_dice
        test = test_dice
        class_balance = False

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if args.cuda:
        model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    # LUNA16 dataset isotropically scaled to 2.5mm^3
    # and then truncated or zero-padded to 160x128x160
    normMu = [-642.794]
    normSigma = [459.512]
    normTransform = transforms.Normalize(normMu, normSigma)

    trainTransform = transforms.Compose([transforms.ToTensor(), normTransform])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    print("loading training set")
    trainSet = dset.LUNA16(root='luna16',
                           images="luna16_ct_normalized",
                           targets=lung_masks,
                           train=True,
                           transform=trainTransform,
                           allow_empty=False,
                           class_balance=class_balance,
                           split=[2, 2, 2],
                           seed=args.seed)
    trainLoader = DataLoader(trainSet,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    print("loading test set")
    testLoader = DataLoader(dset.LUNA16(root='luna16',
                                        images="luna16_ct_normalized",
                                        targets=lung_masks,
                                        train=False,
                                        transform=testTransform,
                                        allow_empty=False,
                                        seed=args.seed,
                                        split=[2, 2, 2]),
                            batch_size=batch_size,
                            shuffle=False,
                            **kwargs)

    target_weight = trainSet.target_weight()
    print(target_weight)
    bg_weight = 1.0 - target_weight
    class_weights = torch.FloatTensor([bg_weight, target_weight])
    if args.cuda:
        class_weights = class_weights.cuda()

    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    err_best = 100.
    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, model, trainLoader, optimizer, trainF,
              class_weights)
        err = test(args, epoch, model, testLoader, optimizer, testF,
                   class_weights)
        torch.save(model, os.path.join(args.save, 'vnet_checkpoint.pth'))
        is_best = False
        if err < best_prec1:
            is_best = True
            best_prec1 = err
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }, is_best, args.save, "vnet")
        os.system('./plot.py {} {} &'.format(len(trainLoader), args.save))

    trainF.close()
    testF.close()
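
train_dice/test_dice are referenced throughout but defined elsewhere; at their core is a soft Dice loss. A minimal binary version for orientation (the real loops also handle class weights, logging and the CSV output):

import torch


def dice_loss(probs, target, eps=1e-6):
    # Soft Dice: 1 - 2|P.G| / (|P| + |G|), computed on probabilities.
    probs = probs.contiguous().view(-1)
    target = target.contiguous().view(-1).float()
    intersection = (probs * target).sum()
    return 1. - (2. * intersection + eps) / (probs.sum() + target.sum() + eps)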
Example #8
def main(params, args):
    best_prec1 = 100.  # accuracy? by Chao
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    resultDir = 'results/vnet.base.{}'.format(datestr())
    nll = True
    if args.dice:
        nll = False
    weight_decay = args.weight_decay
    setproctitle.setproctitle(resultDir)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=False, nll=nll)
    batch_size = args.batchSz
    torch.cuda.set_device(
        0
    )  # why do I have to add this line? It seems the below line is useless to apply GPU devices. By Chao.
    model = nn.parallel.DataParallel(model, device_ids=[0])

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    if nll:
        train = train_nll
        test = test_nll
    else:
        train = train_dice
        test = test_dice

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if args.cuda:
        model = model.cuda()

    if os.path.exists(resultDir):
        shutil.rmtree(resultDir)
    os.makedirs(resultDir, exist_ok=True)

    # transform
    trainTransform = transforms.Compose([transforms.ToTensor()])
    testTransform = transforms.Compose([transforms.ToTensor()])

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    print("\nloading training set")
    dataManagerTrain = DM.DataManager(params['ModelParams']['dirTrain'],
                                      params['ModelParams']['dirResult'],
                                      params['DataManagerParams'])
    dataManagerTrain.loadTrainingData()  # required
    numpyImages = dataManagerTrain.getNumpyImages()
    numpyGT = dataManagerTrain.getNumpyGT()

    trainSet = promise12.PROMISE12(mode='train',
                                   images=numpyImages,
                                   GT=numpyGT,
                                   transform=trainTransform)
    trainLoader = DataLoader(trainSet,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)

    print("\nloading test set")
    dataManagerTest = DM.DataManager(params['ModelParams']['dirTest'],
                                     params['ModelParams']['dirResult'],
                                     params['DataManagerParams'])
    dataManagerTest.loadTestingData()  # required
    numpyImages = dataManagerTest.getNumpyImages()
    numpyGT = dataManagerTest.getNumpyGT()

    testSet = promise12.PROMISE12(mode='test',
                                  images=numpyImages,
                                  GT=numpyGT,
                                  transform=testTransform)
    testLoader = DataLoader(testSet,
                            batch_size=batch_size,
                            shuffle=True,
                            **kwargs)

    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)

    trainF = open(os.path.join(resultDir, 'train.csv'), 'w')
    testF = open(os.path.join(resultDir, 'test.csv'), 'w')

    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, model, trainLoader, optimizer, trainF)
        testDice = test(args, epoch, model, testLoader, optimizer, testF)
        # assume test() returns an error-style metric where lower is better
        is_best = testDice < best_prec1
        if is_best:
            best_prec1 = testDice
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            },
            is_best,
            path=resultDir,
            prefix="vnet")
    os.system('./plot.py {} {} &'.format(len(trainLoader), resultDir))

    trainF.close()
    testF.close()

    # inference, i.e. output predicted mask for test data in .mhd
    if params['ModelParams']['dirInfer'] != '':
        print("loading inference data")
        dataManagerInfer = DM.DataManager(params['ModelParams']['dirInfer'],
                                          params['ModelParams']['dirResult'],
                                          params['DataManagerParams'])
        dataManagerInfer.loadInferData()  # required; loads images without ground truth
        numpyImages = dataManagerInfer.getNumpyImages()

        inferSet = promise12.PROMISE12(mode='infer',
                                       images=numpyImages,
                                       GT=None,
                                       transform=testTransform)
        inferLoader = DataLoader(inferSet,
                                 batch_size=batch_size,
                                 shuffle=False,  # keep inference order deterministic
                                 **kwargs)
        inference(params, args, inferLoader, model)
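
Both weights_init and adjust_opt are called by these examples but missing from the excerpt. Plausible sketches follow; the init scheme and the step-decay thresholds are assumptions, not the original code:

import torch.nn as nn


def weights_init(m):
    # Kaiming init for conv layers; other modules keep PyTorch defaults.
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def adjust_opt(optAlg, optimizer, epoch):
    # Only SGD gets a manual step-decay schedule; adam/rmsprop keep their
    # adaptive defaults (the epoch thresholds below are assumptions).
    if optAlg != 'sgd':
        return
    if epoch < 150:
        lr = 1e-1
    elif epoch < 225:
        lr = 1e-2
    else:
        lr = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr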
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=4)
    parser.add_argument('--dice', action='store_true')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--nEpochs', type=int, default=1)
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i',
                        '--inference',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--weight-decay',
                        '--wd',
                        default=1e-8,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-8)')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    args = parser.parse_args()
    best_prec1 = 100.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = args.save or 'work120/vnet.base.{}'.format(datestr())
    nll = True
    if args.dice:
        nll = False
    weight_decay = args.weight_decay
    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    model = vnet.VNet(elu=False, nll=nll)
    batch_size = args.ngpu * args.batchSz
    gpu_ids = range(args.ngpu)
    if args.cuda:
        model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    if nll:
        train = train_nll
        test = test_nll
        class_balance = True
    else:
        train = train_dice
        test = test_dice
        class_balance = False

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    print('  + Cuda enabled=', args.cuda)
    if args.cuda:
        model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    datadir = '/home/TRAINING/datasets/brats/'
    if args.inference != '':
        if not args.resume:
            print("args.resume must be set to do inference")
            exit(1)
        kwargs = {'num_workers': 1} if args.cuda else {}
        src = args.inference
        dst = args.save
        inference_batch_size = args.ngpu
        dataz = np.load(datadir + src)  # e.g. src = 'hgg_data120.npz'
        alldata = dataz.f.arr_0
        testdata = alldata[170:]
        dataset = []
        for x in testdata:  # iterate samples directly; zip() would wrap each in a 1-tuple
            x = np.expand_dims(x, axis=0)
            dataset.append(x)

        loader = DataLoader(dataset,
                            batch_size=inference_batch_size,
                            shuffle=False,
                            collate_fn=noop,
                            **kwargs)
        # trainTransform is not defined anywhere in this example; assume a
        # minimal tensor transform so the call below is runnable
        trainTransform = transforms.Compose([transforms.ToTensor()])
        inference(args, loader, model, trainTransform)
        return

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    print("loading training set")
    dataz = np.load(datadir + 'hgg_data120.npz')
    alldata = dataz.f.arr_0
    lblz = np.load(datadir + 'hgg_labels120.npz')
    alllabels = lblz.f.arr_0
    alllabels[alllabels > 0] = 1  # for now, collapse all labels into one foreground class
    traindata = alldata[:170]
    testdata = alldata[170:]
    print('max label =', np.max(alllabels))  # 1 after the binarization above
    trainSet = []
    testSet = []
    for x, y in zip(traindata, alllabels[:170]):
        x = np.expand_dims(x, axis=0)
        y = np.expand_dims(y, axis=0)
        trainSet.append((x, y))
    for x, y in zip(testdata, alllabels[170:]):
        x = np.expand_dims(x, axis=0)
        y = np.expand_dims(y, axis=0)
        testSet.append((x, y))

    trainLoader = DataLoader(trainSet,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    print("loading test set")

    testLoader = DataLoader(testSet,
                            batch_size=batch_size,
                            shuffle=False,
                            **kwargs)

    bg_weight = 0.011  # precomputed offline as target_mean / (1. + target_mean)
    fg_weight = 1. - bg_weight
    class_weights = torch.FloatTensor([bg_weight, fg_weight])
    if args.cuda:
        class_weights = class_weights.cuda()

    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, model, trainLoader, optimizer, trainF,
              class_weights)
        err = test(args, epoch, model, testLoader, optimizer, testF,
                   class_weights)
        is_best = False
        if err < best_prec1:
            is_best = True
            best_prec1 = err
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }, is_best, args.save, "vnet")
        os.system('./plot.py {} {} &'.format(len(trainLoader), args.save))

    trainF.close()
    testF.close()
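
The inference DataLoader in this example passes collate_fn=noop, which is also absent from the excerpt. A minimal sketch, assuming it is an identity collate:

def noop(batch):
    # Skip the default tensor stacking; hand the raw list of samples
    # straight to the inference loop.
    return batch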
Example #10
import onnx
import torch
import torch.nn as nn
import torch.onnx

import vnet

model = vnet.VNet(elu=True, nll=False)
model = nn.parallel.DataParallel(model)
dummy_input = torch.randn(1, 1, 64, 224, 224, device='cuda')
state = torch.load('vnet_checkpoint.pth')['state_dict']
model.load_state_dict(state)

model.train(False)
model = model.module.cuda()

torch.onnx.export(model, dummy_input, "vnet.onnx", verbose=True)

model = onnx.load("vnet.onnx")

# Check that the IR is well formed
# onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
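
# Optional sanity check (not part of the original script): run the exported
# model with onnxruntime, assuming that package is installed, before the
# Caffe2 conversion below.
# import onnxruntime as ort
# import numpy as np
# sess = ort.InferenceSession("vnet.onnx")
# inp = np.random.randn(1, 1, 64, 224, 224).astype(np.float32)
# print(sess.run(None, {sess.get_inputs()[0].name: inp})[0].shape)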

# import caffe2.python.onnx.backend as c2
from onnx_caffe2.backend import Caffe2Backend
model_name = 'Vnet'
init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(model.graph,