Exemplo n.º 1
0
def predict(config, test_on, is_train, fold):
    """Run a trained VNet on the test split, print per-case IoU and save raw outputs.

    For every batch from the test loader the argmax prediction is compared with
    the ground truth via ``IoU`` (14 classes). The squeezed network output is
    saved to
    ``/mnt/EXTRA/datasets/competitions/aug/<test_on>/<case>/vnet-fold<fold>-z128-halved-clahe.npy``.

    Args:
        config: run configuration; ``model_type``, ``test_path``, ``batch_size``
            and ``net_path`` are read.
        test_on: split name forwarded to ``ProbSet`` and used in the save path.
        is_train: forwarded to ``ProbSet``.
        fold: cross-validation fold forwarded to ``ProbSet`` and used in the
            output file name.

    Returns:
        None. Prints an error and returns early if ``config.model_type`` is not
        supported.
    """
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return

    test_set = ProbSet(config.test_path,
                       is_train=is_train,
                       is_aug=False,
                       return_params=True,
                       test_on=test_on,
                       fold=fold)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = VNet()
    net.to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    net.load_state_dict(torch.load(config.net_path, map_location=device))
    net.eval()

    # Inference only: skip building the autograd graph for large 3-D volumes.
    with torch.no_grad():
        for imgs, gts, _, case in test_loader:
            # The loader wraps each per-sample field in a batch container;
            # taking element 0 assumes batch_size == 1.
            case = case[0]

            imgs = imgs.to(device)
            gts = gts.round().long().to(device)

            outputs = net(imgs)
            print(gts.cpu().shape, imgs.shape, outputs.shape)
            # e.g. torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 14, 128, 128, 128])
            logits = outputs.detach().cpu().squeeze()
            ious = IoU(gts.detach().cpu().squeeze().numpy().reshape(-1),
                       logits.argmax(dim=0).numpy().reshape(-1),
                       num_classes=14)
            print(ious)
            print(np.array(ious).mean())

            logits_np = logits.numpy()
            # Use the ``test_on`` argument so the file lands in the directory
            # of the split that was actually evaluated.
            np.save(
                '/mnt/EXTRA/datasets/competitions/aug/{}/{}/vnet-fold{}-z128-halved-clahe.npy'
                .format(test_on, case, fold), logits_np)
            print(case, logits_np.shape)
Exemplo n.º 2
0
def main():
    """Train VNet for segmentation from pretrained weights or a resumed checkpoint.

    Settings come from ``parse_args()``. Experiment output goes under
    ``exps/exp_<args.exp>``, with ``log`` and ``model`` sub-directories.
    When ``args.resume`` is None, the ``'state_dict'`` entry of
    ``args.load_path`` is overlaid onto the model. Otherwise model and optimizer
    state are restored from ``args.resume``. Each epoch then runs ``train``,
    ``validate`` and ``adjust_learning_rate``.

    Raises:
        FileNotFoundError: if ``args.resume`` is None and ``args.load_path``
            does not exist.
    """
    args = parse_args()
    args.pretrain = False

    root_path = 'exps/exp_{}'.format(args.exp)
    # exist_ok also completes a partially created experiment directory.
    for sub_dir in ("log", "model"):
        os.makedirs(os.path.join(root_path, sub_dir), exist_ok=True)

    train_dataset, val_dataset = build_dataset(args.dataset, args.data_root,
                                               args.train_list)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)

    model = torch.nn.DataParallel(model)
    model.train()

    if args.resume is None:
        # An explicit exception still fires under `python -O`, unlike assert.
        if not os.path.exists(args.load_path):
            raise FileNotFoundError(
                'pretrained weights not found: {}'.format(args.load_path))
        print("Loading weights...")
        pretrain_state_dict = torch.load(args.load_path,
                                         map_location="cpu")['state_dict']
        state_dict = model.state_dict()
        # Overlay only the checkpoint entries the model knows about. Model
        # entries absent from the checkpoint keep their current values instead
        # of making the strict load_state_dict call fail on missing keys.
        state_dict.update({
            k: v
            for k, v in pretrain_state_dict.items() if k in state_dict
        })
        model.load_state_dict(state_dict)
        print("Loaded weights")
    else:
        print("Resuming from {}".format(args.resume))
        checkpoint = torch.load(args.resume, map_location="cpu")

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(root_path)
    saver = Saver(root_path)

    for epoch in range(args.start_epoch, args.epochs):
        train(model, train_loader, optimizer, logger, args, epoch)
        validate(model, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)
Exemplo n.º 3
0
    # read the list path from the cross validation
    image_list = open(list_path).readlines()
    assert os.path.exists(args.load_path)
    state_dict = torch.load(args.load_path, map_location="cpu")['state_dict']
    new_state_dict = OrderedDict()
    for key in state_dict.keys():
        new_state_dict[key[7:]] = state_dict[key]

    state_dict = net.state_dict()
    print("Loading weights...")
    for k in list(new_state_dict.keys()):
        if k not in state_dict:
            del new_state_dict[k]

    state_dict.update(new_state_dict)
    net.load_state_dict(state_dict)
    net.cuda()
    net.eval()

    # test passed for the first case
    for i in range(0, len(image_list)):
        file_name = image_list[i].strip('\n')
        if '/' in file_name:
            file_name = os.path.basename(file_name)
        case_list.append(file_name)
        imglabelpath = os.path.join(root_dir, file_name)
        image, label = load_data_test(imglabelpath,
                                      dataset=args.dataset,
                                      convert_msd=False)
        print(label.shape)
        map_name = os.path.join(votesave_path, file_name + '.npz')
Exemplo n.º 4
0
class Solver(object):
    """Train, validate and test a VNet segmentation network.

    The validation/test score is the mean of the per-class values returned by
    ``IoU`` with class 0 excluded, averaged over samples. The ``DC`` label in
    printed output and in ``result.csv`` is kept for compatibility even though
    the value comes from ``IoU``.
    """

    # Number of classes passed to IoU and expand_as_one_hot.
    NUM_CLASSES = 14
    # Checkpoint VNet is initialised from in build_model().
    PRETRAINED_PATH = (
        '/mnt/HDD/datasets/competitions/vnet/models_for_cls/'
        'VNet-400-0.0001000-200-0.5000-ce-400-200-vnet-dice+ce.pkl')

    def __init__(self, args, train_loader, val_loader, test_loader):
        # data loaders
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # model and optimizer are created in build_model()
        self.net = None
        self.optimizer = None

        self.criterion = FocalLoss(alpha=0.8, gamma=0.5)  # torch.nn.BCELoss()
        self.augmentation_prob = args.augmentation_prob

        # hyper-parameters
        self.lr = args.lr
        self.decayed_lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # training settings
        self.num_epochs = args.num_epochs
        self.num_epochs_decay = args.num_epochs_decay
        self.batch_size = args.batch_size

        # step size for logging and validation
        self.log_step = args.log_step
        self.val_step = args.val_step

        # A checkpoint is only written when validation beats this score.
        self.best_score = 0.549
        self.best_epoch = 0
        self.model_path = args.model_path
        self.csv_path = args.result_path
        self.model_type = args.model_type

        self.comment = args.comment

        self.net_path = os.path.join(
            self.model_path, '%s-%d-%.7f-%d-%.4f-%s.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay,
             self.augmentation_prob, self.comment))

        # TODO: multi-GPU setting
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.build_model()

    def build_model(self):
        """Create the network (VNet from PRETRAINED_PATH) and its Adam optimizer."""
        if self.model_type == 'VNet':
            self.net = VNet()
            # Load on CPU first so a GPU-saved checkpoint works on any host;
            # the model is moved to self.device right after.
            self.net.load_state_dict(
                torch.load(self.PRETRAINED_PATH, map_location='cpu'))

        self.net.to(self.device)
        self.optimizer = optim.Adam(self.net.parameters(), self.lr,
                                    [self.beta1, self.beta2])

    def print_network(self, model, name):
        """Print ``model``, ``name`` and the total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print('the number of parameters: {}'.format(num_params))

    # =============================== helpers =======================#
    @staticmethod
    def _per_sample_flat(outputs, gts):
        """Yield flattened ``(label, prediction)`` numpy arrays per sample.

        ``outputs`` carries per-class scores on dim 1; the prediction is their
        argmax. ``gts`` must hold the same number of voxels per sample.
        """
        preds = outputs.detach().argmax(dim=1).cpu().numpy()
        labels = gts.detach().cpu().numpy().reshape(preds.shape)
        for label, pred in zip(labels, preds):
            yield label.reshape(-1), pred.reshape(-1)

    def _foreground_iou_sum(self, outputs, gts):
        """Return the batch sum of per-sample mean IoU, excluding class 0."""
        return sum(
            np.array(IoU(label, pred, num_classes=self.NUM_CLASSES)[1:]).mean()
            for label, pred in self._per_sample_flat(outputs, gts))

    @torch.no_grad()
    def _evaluate(self, loader):
        """Put the net in eval mode and return its mean foreground IoU over ``loader``."""
        self.net.eval()
        total, length = 0., 0
        for imgs, gts in loader:
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs)
            # Scored per sample, so batches larger than 1 are averaged correctly.
            total += self._foreground_iou_sum(outputs, gts)
            length += imgs.size(0)
        return total / length

    # =============================== train =========================#
    def train(self, epoch):
        """Run one training epoch with the generalized Dice loss.

        During the final ``num_epochs_decay`` epochs the learning rate is
        lowered by ``lr / num_epochs_decay`` per epoch, reaching 0 at the last
        epoch.
        """
        self.net.train(True)

        if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
            self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.decayed_lr
            print('epoch{}: Decay learning rate to lr: {}.'.format(
                epoch, self.decayed_lr))

        epoch_loss = 0.
        num_samples = 0
        for imgs, gts in tqdm(self.train_loader):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)

            self.optimizer.zero_grad()
            outputs = self.net(imgs)

            # Labels are one-hot encoded to NUM_CLASSES channels; the reshape
            # hard-codes 128^3 volumes.
            loss = GeneralizedDiceLoss(sigmoid_normalization=False)(
                outputs,
                expand_as_one_hot(gts.reshape(-1, 128, 128, 128),
                                  self.NUM_CLASSES))
            # The loss is a batch mean, so weight it by the batch size.
            epoch_loss += loss.item() * imgs.size(0)
            num_samples += imgs.size(0)
            loss.backward()
            self.optimizer.step()

        print('EPOCH{}'.format(epoch))
        print('[Training] Loss: %.4f' % (epoch_loss / max(num_samples, 1)))

    # =============================== validation ====================#
    @torch.no_grad()
    def validation(self, epoch):
        """Score the model on the validation set and save it when it beats best_score."""
        DC = self._evaluate(self.val_loader)

        print('[Validation] DC: %.4f' % (DC))

        if DC > self.best_score:
            self.best_score = DC
            self.best_epoch = epoch
            print('Best %s model score: %.4f' %
                  (self.model_type, self.best_score))
            torch.save(self.net.state_dict(), self.net_path)

    def test(self):
        """Reload the best checkpoint, score it on the test set and append to result.csv."""
        del self.net
        self.build_model()
        self.net.load_state_dict(
            torch.load(self.net_path, map_location=self.device))

        DC = self._evaluate(self.test_loader)

        with open(os.path.join(self.csv_path, 'result.csv'),
                  'a',
                  encoding='utf8',
                  newline='') as f:
            csv.writer(f).writerow([
                self.model_type, DC, self.lr, self.best_epoch, self.num_epochs,
                self.num_epochs_decay, self.augmentation_prob, self.batch_size,
                self.comment
            ])

    def train_val_test(self):
        """Train and validate for ``num_epochs`` epochs, then run the test."""
        for epoch in range(self.num_epochs):
            self.train(epoch)
            self.validation(epoch)

        self.test()
Exemplo n.º 5
0
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=args.num_workers,
    pin_memory=True,
)

# Two identically configured VNets: `model` is built first, then `model_ema`.
model, model_ema = (
    VNet(args.n_channels, args.n_classes, input_size=64, pretrain=True).cuda()
    for _ in range(2)
)

# The optimizer only updates `model`; `model_ema` gets no optimizer here.
optimizer = torch.optim.SGD(
    model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0005)

model, model_ema = (torch.nn.DataParallel(net) for net in (model, model_ema))

# Start `model_ema` from exactly the same weights as `model`.
model_ema.load_state_dict(model.state_dict())
print("Model Initialized")

logger = Logger(root_path)
saver = Saver(root_path, save_freq=args.save_freq)

if args.sampling not in ('default', 'layerwise'):
    raise ValueError("unsupported sampling method")
contrast = (RGBMoCo if args.sampling == 'default' else RGBMoCoNew)(
    128, K=4096, T=args.temperature).cuda()

criterion = torch.nn.CrossEntropyLoss()

for epoch in range(args.start_epoch, args.epochs):
    pretrain(model, model_ema, train_loader, optimizer, logger, saver, args,
             epoch, contrast, criterion)
    adjust_learning_rate(args, optimizer, epoch)
Exemplo n.º 6
0
def predict(config):
    """Evaluate a trained VNet on the fold-5 test split and visualise slices.

    For every test batch this prints the tensor shapes, the per-class IoU list
    (14 classes) and its mean. It then shows slices 70..127 (image, ground
    truth and argmax prediction) with matplotlib, pausing 2 s after each
    figure. Finally it prints the mean IoU over all evaluated batches.

    Args:
        config: run configuration; ``model_type``, ``test_path``, ``batch_size``
            and ``net_path`` are read.

    Returns:
        None. Prints an error and returns early if ``config.model_type`` is not
        supported.
    """
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return

    test_set = ProbSet(config.test_path, is_train=False, fold=5)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    net = VNet()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    print(config.model_type, net)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    net.load_state_dict(torch.load(config.net_path, map_location=device))
    net.eval()

    iou = 0.
    n_batches = 0
    # Inference only: skip building the autograd graph for large 3-D volumes.
    with torch.no_grad():
        for imgs, gts in test_loader:
            imgs = imgs.to(device)
            gts = gts.round().long().to(device)

            outputs = net(imgs)
            print(gts.cpu().shape, imgs.shape, outputs.shape)
            # e.g. torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 14, 128, 128, 128])

            # squeeze() + argmax over axis 0 assume batch_size == 1.
            img_vol = imgs.cpu().squeeze().numpy()
            gt_vol = gts.cpu().numpy().squeeze()
            pred_vol = outputs.cpu().numpy().squeeze().argmax(axis=0)

            ious = IoU(gt_vol.reshape(-1), pred_vol.reshape(-1), num_classes=14)
            print(ious)
            mean_iou = np.array(ious).mean()
            print(mean_iou)
            iou += mean_iou
            n_batches += 1

            for j in range(70, 128):
                fig = plt.figure()
                plt.subplot(2, 2, 1)
                plt.imshow(img_vol[j])
                plt.colorbar()
                plt.subplot(2, 2, 2)
                plt.title(np.unique(gt_vol[j]))
                plt.imshow(gt_vol[j])
                plt.colorbar()
                plt.subplot(2, 2, 3)
                plt.title(np.unique(pred_vol[j]))
                plt.imshow(pred_vol[j].reshape(128, 128))
                plt.colorbar()
                plt.show()
                time.sleep(2)
                # Release the figure; otherwise one figure per slice accumulates.
                plt.close(fig)

    # Average over the batches actually evaluated, not a hard-coded 10.
    print('######', iou / max(n_batches, 1))