Ejemplo n.º 1
0
def predict(config, test_on, is_train, fold):
    """Run a trained VNet over one test split and save its raw outputs.

    For every case yielded by ``ProbSet(config.test_path, ...)`` this prints
    the per-class IoU (14 classes) against the ground truth. It also saves the
    squeezed network output to
    ``/mnt/EXTRA/datasets/competitions/aug/<test_on>/<case>/vnet-fold<fold>-z128-halved-clahe.npy``.

    Args:
        config: object exposing ``model_type``, ``test_path``, ``batch_size``
            and ``net_path``.
        test_on: split identifier. It is forwarded to ``ProbSet`` and used as
            the first directory component of the output path.
        is_train: forwarded to ``ProbSet``.
        fold: cross-validation fold. It is forwarded to ``ProbSet`` and
            embedded in the output file name.

    Returns:
        The mean IoU over all processed batches (``0.0`` for an empty loader),
        or ``None`` when ``config.model_type`` is not supported.
    """
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return None

    test_set = ProbSet(config.test_path,
                       is_train=is_train,
                       is_aug=False,
                       return_params=True,
                       test_on=test_on,
                       fold=fold)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    net = VNet()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    net.load_state_dict(torch.load(config.net_path, map_location=device))
    net.eval()

    iou_sum = 0.
    n_batches = 0
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for imgs, gts, _, case in test_loader:
            # The loader wraps every field in a batch. Only the first case name
            # is used, so this assumes batch_size == 1.
            case = case[0]

            imgs = imgs.to(device)
            gts = gts.round().long().to(device)

            outputs = net(imgs)
            print(gts.cpu().shape, imgs.shape, outputs.shape)
            # e.g. gts/imgs: [1, 1, 128, 128, 128]; outputs: [1, 14, 128, 128, 128]
            scores = outputs.cpu().squeeze()

            ious = IoU(gts.cpu().squeeze().numpy().reshape(-1),
                       scores.argmax(dim=0).numpy().reshape(-1),
                       num_classes=14)
            mean_iou = np.array(ious).mean()
            print(ious)
            print(mean_iou)
            iou_sum += mean_iou
            n_batches += 1

            scores_np = scores.numpy()
            # Use the test_on argument (the same split ProbSet was built with)
            # rather than an unrelated global.
            np.save(
                '/mnt/EXTRA/datasets/competitions/aug/{}/{}/vnet-fold{}-z128-halved-clahe.npy'
                .format(test_on, case, fold),
                scores_np)
            print(case, scores_np.shape)

    return iou_sum / n_batches if n_batches else 0.0
Ejemplo n.º 2
0
    def build_model(self):
        """Instantiate the network, move it to ``self.device`` and create its optimizer.

        Reads ``self.model_type``, ``self.lr``, ``self.beta1``, ``self.beta2``
        and ``self.device``. Sets ``self.net`` and ``self.optimizer`` (Adam).

        Raises:
            ValueError: if ``self.model_type`` is not a supported architecture.
        """
        if self.model_type == 'VNet':
            self.net = VNet()
        else:
            # Fail with a clear message instead of an AttributeError on self.net.
            raise ValueError('Unsupported model_type: %s' % self.model_type)

        # Move parameters to the target device before constructing the
        # optimizer, as the PyTorch optimizer docs recommend.
        self.net.to(self.device)
        self.optimizer = optim.Adam(self.net.parameters(), self.lr,
                                    [self.beta1, self.beta2])
Ejemplo n.º 3
0
    def build_model(
            self,
            checkpoint_path='/mnt/HDD/datasets/competitions/vnet/models_for_cls/VNet-400-0.0001000-200-0.5000-ce-400-200-vnet-dice+ce.pkl'):
        """Build a VNet from a saved checkpoint, move it to ``self.device`` and create its optimizer.

        Args:
            checkpoint_path: state-dict file loaded into the freshly built VNet.
                It defaults to the previously hard-coded checkpoint.

        Reads ``self.model_type``, ``self.lr``, ``self.beta1``, ``self.beta2``
        and ``self.device``. Sets ``self.net`` and ``self.optimizer`` (Adam).

        Raises:
            ValueError: if ``self.model_type`` is not a supported architecture.
        """
        if self.model_type == 'VNet':
            self.net = VNet()
            # map_location lets a GPU-saved checkpoint load on whatever
            # device this run uses (including CPU-only hosts).
            self.net.load_state_dict(
                torch.load(checkpoint_path, map_location=self.device))
        else:
            # Fail with a clear message instead of an AttributeError on self.net.
            raise ValueError('Unsupported model_type: %s' % self.model_type)

        # Move parameters to the target device before constructing the
        # optimizer, as the PyTorch optimizer docs recommend.
        self.net.to(self.device)
        self.optimizer = optim.Adam(self.net.parameters(), self.lr,
                                    [self.beta1, self.beta2])
Ejemplo n.º 4
0
def main():
    """Train a VNet, starting from pretrained weights or resuming a checkpoint.

    Experiment artefacts live under ``exps/exp_<args.exp>/{log,model}``.

    When ``args.resume`` is None, weights are read from ``args.load_path``
    (its ``'state_dict'`` entry). Only keys the model also has are applied;
    all other parameters keep their fresh initialisation.

    Otherwise, model and optimizer state are restored from ``args.resume``.
    Each epoch then runs ``train``, ``validate`` and ``adjust_learning_rate``.

    Raises:
        FileNotFoundError: if not resuming and ``args.load_path`` is missing.
    """
    args = parse_args()
    args.pretrain = False

    root_path = 'exps/exp_{}'.format(args.exp)
    # exist_ok makes re-runs safe. It also fills in sub-directories missing
    # from a partially created experiment folder.
    for sub_dir in ("log", "model"):
        os.makedirs(os.path.join(root_path, sub_dir), exist_ok=True)

    train_dataset, val_dataset = build_dataset(args.dataset, args.data_root,
                                               args.train_list)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)

    # Wrapped before loading, so state-dict keys carry the 'module.' prefix.
    model = torch.nn.DataParallel(model)
    model.train()

    if args.resume is None:
        if args.load_path is None or not os.path.exists(args.load_path):
            raise FileNotFoundError(
                "Pretrained weights not found: {}".format(args.load_path))
        state_dict = model.state_dict()
        print("Loading weights...")
        pretrain_state_dict = torch.load(args.load_path,
                                         map_location="cpu")['state_dict']
        # Overlay the matching pretrained tensors on the model's own state.
        # A strict load of the filtered dict alone would fail whenever the
        # checkpoint lacks some of the model's keys.
        state_dict.update({k: v for k, v in pretrain_state_dict.items()
                           if k in state_dict})
        model.load_state_dict(state_dict)
        print("Loaded weights")
    else:
        print("Resuming from {}".format(args.resume))
        checkpoint = torch.load(args.resume, map_location="cpu")

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(root_path)
    saver = Saver(root_path)

    for epoch in range(args.start_epoch, args.epochs):
        train(model, train_loader, optimizer, logger, args, epoch)
        validate(model, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)
Ejemplo n.º 5
0
                        shuffle=False,
                        drop_last=False)
    # Test-set loader: no shuffling, keeps the final partial batch.
    # NOTE(review): the positional `1` is presumably the batch size of the
    # project's Loader; confirm against its signature.
    test_loader = Loader(test_path,
                         1,
                         torch_type=arg.dtype,
                         cpus=arg.cpus,
                         shuffle=False,
                         drop_last=False)
    # NOTE(review): norm_layer is assigned but never used below; the residual
    # variants are built with nn.InstanceNorm2d instead.
    norm_layer = nn.BatchNorm2d

    act = nn.ReLU
    # Select the architecture from arg.model.
    # NOTE(review): there is no else-branch, so an unrecognised arg.model
    # leaves `net` unbound and the DataParallel line raises NameError.
    if arg.model == "unet":
        net = Unet2D(feature_scale=arg.feature_scale, act=act)
    elif arg.model == "unetres":
        net = UnetRes2D(1, nn.InstanceNorm2d, is_pool=arg.pool)
    elif arg.model == "unetbr":
        net = UnetBR2D(1, nn.InstanceNorm2d, is_pool=arg.pool)
    elif arg.model == "vnet":
        net = VNet(elu=False)

    net = nn.DataParallel(net).to(torch_device)
    # Reconstruction loss on raw logits (BCEWithLogitsLoss applies the sigmoid).
    recon_loss = nn.BCEWithLogitsLoss()

    # VNet uses its own trainer; every other architecture shares CNNTrainer.
    if arg.model == "vnet":
        model = VNetTrainer(arg, net, torch_device, recon_loss=recon_loss)
    else:
        model = CNNTrainer(arg, net, torch_device, recon_loss=recon_loss)
    # Train unless running in test-only mode; testing always runs afterwards.
    if arg.test is False:
        model.train(train_loader, val_loader)
    model.test(test_loader, val_loader)
Ejemplo n.º 6
0
# Expose the first setup.GPUS devices to CUDA (device 0 when GPUS is falsy).
if setup.GPUS:
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(setup.GPUS)))
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# set paths to scripts, templates, weights etc.
TEMPLATE_PATH = os.path.join('templates',
                             'scct_unsmooth_SS_0.01_128x128x128.nii.gz')

# Fall back to the local 'weights' file when setup.weights is not given.
WEIGHT_PATH = setup.weights if setup.weights else 'weights'

# load the model and weights
model = VNet()
model.load_weights(WEIGHT_PATH)

# Set up directory trees. makedirs also creates missing parent directories and
# tolerates a directory that already exists (no check-then-create race).
IN_DIR = setup.IN_DIR
OUT_DIR = setup.OUT_DIR
os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# Load input data: regular files directly inside IN_DIR (not recursive), sorted.
files = sorted(next(os.walk(IN_DIR))[2])
files = [os.path.join(IN_DIR, f) for f in files]
template = ants.image_read(TEMPLATE_PATH, pixeltype='float')
Ejemplo n.º 7
0
classes = ['Pancreas']
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
print('processing dataset {}'.format(args.dataset))

root_dir, list_path = get_test_paths(args.dataset, args.data_root)

stride = 20
n_class = 2

votesave_path = os.path.join(args.votemap)
os.makedirs(votesave_path, exist_ok=True)

patch_size = 64

if __name__ == "__main__":
    net = VNet(args.n_channels, args.n_classes).cuda()

    dices = []
    dice_for_cases = []
    case_list = []
    sys.stdout.flush()

    # Read the case list for this cross-validation split; `with` closes the file.
    with open(list_path) as list_file:
        image_list = list_file.readlines()

    if args.load_path is None or not os.path.exists(args.load_path):
        raise FileNotFoundError(
            "Checkpoint not found: {}".format(args.load_path))
    state_dict = torch.load(args.load_path, map_location="cpu")['state_dict']

    # Checkpoints saved from a DataParallel/DDP wrapper prefix every key with
    # 'module.'. Strip it only when present, so unwrapped checkpoints keep
    # their keys intact instead of losing their first 7 characters.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[new_key] = value
Ejemplo n.º 8
0
                                           args.data_root,
                                           args.train_list,
                                           sampling=args.sampling)

# Training loader: fixed order (shuffle=False); the last incomplete batch is
# dropped.
# NOTE(review): training without shuffling is unusual; presumably ordering is
# controlled by the dataset's `sampling` mode. Confirm this is intended.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           pin_memory=True,
                                           drop_last=True)

# Validation loader: one sample per batch.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=1,
                                         num_workers=args.num_workers,
                                         pin_memory=True)
# Two identically configured networks: the trained model and a second copy.
# The name suggests an exponential-moving-average model; its update rule is
# not visible here.
model = VNet(args.n_channels, args.n_classes, input_size=64,
             pretrain=True).cuda()
model_ema = VNet(args.n_channels, args.n_classes, input_size=64,
                 pretrain=True).cuda()

# Only `model` is optimised; model_ema's parameters are not given to SGD.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=args.lr,
                            momentum=0.9,
                            weight_decay=0.0005)
model = torch.nn.DataParallel(model)
model_ema = torch.nn.DataParallel(model_ema)
# Start both copies from identical weights.
model_ema.load_state_dict(model.state_dict())
print("Model Initialized")
logger = Logger(root_path)
saver = Saver(root_path, save_freq=args.save_freq)
# NOTE(review): RGBMoCo is project-defined. The arguments look like feature
# dim 128, queue size K=4096 and temperature T; confirm against its definition.
if args.sampling == 'default':
    contrast = RGBMoCo(128, K=4096, T=args.temperature).cuda()
Ejemplo n.º 9
0
def predict(config):
    """Evaluate a trained VNet on fold 5 of the test set and plot prediction slices.

    For every batch this prints the per-class IoU (14 classes). It also shows,
    for each slice from 70 up to min(128, depth), a figure with:
    - the input image,
    - the ground truth,
    - the predicted label map (argmax over classes).

    After all batches, the mean IoU is printed.

    Args:
        config: object exposing ``model_type``, ``test_path``, ``batch_size``
            and ``net_path``.

    Returns:
        The mean IoU over all batches (``0.0`` for an empty loader), or
        ``None`` when ``config.model_type`` is not supported.
    """
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return None

    test_set = ProbSet(config.test_path, is_train=False, fold=5)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    net = VNet()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    print(config.model_type, net)

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    net.load_state_dict(torch.load(config.net_path, map_location=device))
    net.eval()

    iou_sum = 0.
    n_batches = 0
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for imgs, gts in test_loader:
            imgs = imgs.to(device)
            gts = gts.round().long().to(device)

            outputs = net(imgs)
            print(gts.cpu().shape, imgs.shape, outputs.shape)
            # e.g. imgs/gts: [1, 1, 128, 128, 128]; outputs: [1, 14, 128, 128, 128]

            # Convert once per batch instead of once per plotted slice.
            # squeeze() followed by argmax over dim 0 (the class axis) assumes
            # batch_size == 1.
            img_vol = imgs.cpu().squeeze().numpy()
            gt_vol = gts.cpu().squeeze().numpy()
            pred_vol = outputs.cpu().squeeze().argmax(dim=0).numpy()

            ious = IoU(gt_vol.reshape(-1), pred_vol.reshape(-1), num_classes=14)
            mean_iou = np.array(ious).mean()
            print(ious)
            print(mean_iou)
            iou_sum += mean_iou
            n_batches += 1

            # Stop at the volume's real depth rather than assuming 128 slices.
            for j in range(70, min(128, pred_vol.shape[0])):
                fig = plt.figure()
                plt.subplot(2, 2, 1)
                plt.imshow(img_vol[j])
                plt.colorbar()
                plt.subplot(2, 2, 2)
                plt.title(np.unique(gt_vol[j]))
                plt.imshow(gt_vol[j])
                plt.colorbar()
                plt.subplot(2, 2, 3)
                plt.title(np.unique(pred_vol[j]))
                plt.imshow(pred_vol[j])
                plt.colorbar()
                plt.show()
                # Release the figure so memory does not grow with every slice.
                plt.close(fig)
                time.sleep(2)

    # Average over the batches actually seen, not a hard-coded count.
    mean_over_batches = iou_sum / n_batches if n_batches else 0.0
    print('######', mean_over_batches)
    return mean_over_batches
Ejemplo n.º 10
0
def main():
    """Train a VNet from scratch, optionally with DistributedDataParallel.

    ``args.gpu`` is a comma-separated device list. When it names more than one
    GPU, an NCCL process group is initialised and this process trains on
    ``args.local_rank`` with a ``DistributedSampler``.

    Experiment artefacts live under ``exps/exp_<args.exp>/{log,model}``. Each
    epoch runs ``train``, ``validate`` and ``adjust_learning_rate``.
    """
    args = parse_args()
    args.pretrain = False
    print("Using GPU: {}".format(args.local_rank))
    root_path = 'exps/exp_{}'.format(args.exp)
    # Every rank constructs Logger/Saver on root_path below. exist_ok lets all
    # ranks create the tree concurrently, so none can race ahead of rank 0.
    for sub_dir in ("log", "model"):
        os.makedirs(os.path.join(root_path, sub_dir), exist_ok=True)

    # Set before the first CUDA call below so device indices map onto args.gpu.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    train_dataset, val_dataset = build_dataset(args.dataset, args.data_root,
                                               args.train_list)
    args.world_size = len(args.gpu.split(","))
    if args.world_size > 1:
        os.environ['MASTER_PORT'] = args.port
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group('nccl')
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=args.local_rank)
    else:
        train_sampler = None

    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes).cuda(args.local_rank)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    if args.world_size > 1:
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    model.train()
    # No weights are loaded in this entry point; the model starts freshly initialised.
    print("Model initialized")

    logger = Logger(root_path)
    saver = Saver(root_path)

    for epoch in range(args.start_epoch, args.epochs):
        if train_sampler is not None:
            # Reseed the distributed shuffle. Without this, every epoch
            # iterates the data in the same order.
            train_sampler.set_epoch(epoch)
        train(model, train_loader, optimizer, logger, args, epoch)
        validate(model, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)