def predict(config, test_on, is_train, fold):
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return

    # train_set = ProbSet(config.train_path)
    # valid_set = ProbSet(config.valid_path, is_train=False)
    test_set = ProbSet(config.test_path,
                       is_train=is_train,
                       is_aug=False,
                       return_params=True,
                       test_on=test_on,
                       fold=fold)
    # print(len(valid_set), len(test_set))
    # train_loader = DataLoader(train_set, batch_size=config.batch_size)
    # valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    net = VNet()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    # print(config.model_type, net)
    net.load_state_dict(torch.load(config.net_path))
    net.eval()

    DC = 0.  # Dice Coefficient
    length = 0
    iou = 0
    for i, (imgs, gts, _, case) in enumerate(test_loader):
        # The loader wraps each field in a batch tuple; with batch_size=1 we
        # take the first (and only) element.
        case = case[0]
        imgs = imgs.to(device)
        gts = gts.round().long().to(device)
        outputs = net(imgs)
        print(gts.cpu().shape, imgs.shape, outputs.shape)
        # torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 14, 128, 128, 128])

        # per-class IoU over the flattened volume
        ious = IoU(
            gts.detach().cpu().squeeze().numpy().reshape(-1),
            outputs.detach().cpu().squeeze().argmax(dim=0).numpy().reshape(-1),
            num_classes=14)
        print(ious)
        print(np.array(ious).mean())
        iou += np.array(ious).mean()

        # save the raw per-class probability volume for this case
        np.save(
            '/mnt/EXTRA/datasets/competitions/aug/{}/{}/vnet-fold{}-z128-halved-clahe.npy'
            .format(test_on, case, fold),
            outputs.detach().cpu().squeeze().numpy())
        print(case, outputs.detach().cpu().squeeze().numpy().shape)
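# --- hypothetical driver (not part of the original script) --------------------
# A minimal sketch of how predict() could be invoked. The argparse fields mirror
# what the function reads (model_type, test_path, net_path, batch_size); the
# default paths and the test_on / fold values below are placeholders, not the
# project's real configuration.
def _example_predict_cli():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default='VNet')
    parser.add_argument('--test_path', type=str, default='/path/to/test')      # placeholder
    parser.add_argument('--net_path', type=str, default='/path/to/model.pkl')  # placeholder
    parser.add_argument('--batch_size', type=int, default=1)  # predict() assumes batch_size=1
    config = parser.parse_args()
    # test_on / is_train / fold mirror the ProbSet arguments used in predict()
    predict(config, test_on='val', is_train=False, fold=1)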
class Solver(object):
    def __init__(self, args, train_loader, val_loader, test_loader):
        # data loaders
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # model
        self.net = None
        self.optimizer = None
        self.criterion = FocalLoss(alpha=0.8, gamma=0.5)  # torch.nn.BCELoss()
        self.augmentation_prob = args.augmentation_prob

        # hyper-parameters
        self.lr = args.lr
        self.decayed_lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # training settings
        self.num_epochs = args.num_epochs
        self.num_epochs_decay = args.num_epochs_decay
        self.batch_size = args.batch_size

        # step sizes for logging and validation
        self.log_step = args.log_step
        self.val_step = args.val_step

        # paths
        self.best_score = 0.549
        self.best_epoch = 0
        self.model_path = args.model_path
        self.csv_path = args.result_path
        self.model_type = args.model_type
        self.comment = args.comment
        self.net_path = os.path.join(
            self.model_path,
            '%s-%d-%.7f-%d-%.4f-%s.pkl' %
            (self.model_type, self.num_epochs, self.lr,
             self.num_epochs_decay, self.augmentation_prob, self.comment))

        ########### TODO: multi-GPU setting ##########
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.build_model()

    def build_model(self):
        if self.model_type == 'VNet':
            ###### TODO ########
            self.net = VNet()
            self.net.load_state_dict(
                torch.load(
                    '/mnt/HDD/datasets/competitions/vnet/models_for_cls/VNet-400-0.0001000-200-0.5000-ce-400-200-vnet-dice+ce.pkl'
                ))
        self.optimizer = optim.Adam(self.net.parameters(), self.lr,
                                    [self.beta1, self.beta2])
        self.net.to(self.device)
        # self.print_network(self.net, self.model_type)

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()  # numel() returns the total number of elements in a tensor
        print(model)
        print(name)
        print('the number of parameters: {}'.format(num_params))

    # =============================== train =========================#
    # ===============================================================#
    def train(self, epoch):
        self.net.train(True)

        # decay the learning rate over the last num_epochs_decay epochs
        if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
            self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.decayed_lr
            print('epoch{}: Decay learning rate to lr: {}.'.format(
                epoch, self.decayed_lr))

        epoch_loss = 0
        acc = 0.  # Accuracy
        SE = 0.   # Sensitivity (Recall)
        SP = 0.   # Specificity
        PC = 0.   # Precision
        F1 = 0.   # F1 Score
        JS = 0.   # Jaccard Similarity
        DC = 0.   # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(tqdm(self.train_loader)):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)

            self.optimizer.zero_grad()
            outputs = self.net(imgs)

            # make sure shapes are the same by flattening them
            # weight = torch.tensor([1., 100., 100., 100., 50., 50., 80., 80., 50., 80., 80., 80.,
            #                        50., 50., 70., 70., 70., 70., 60., 60., 100., 100., 100.]).to(self.device)
            # ce_loss = nn.CrossEntropyLoss(weight=weight, reduction='mean')(outputs, gts.reshape(-1, 128, 128, 128))
            dice_loss = GeneralizedDiceLoss(sigmoid_normalization=False)(
                outputs, expand_as_one_hot(gts.reshape(-1, 128, 128, 128), 14))
            # bce_loss = torch.nn.BCEWithLogitsLoss()(outputs, gts)
            # focal_loss = FocalLoss(alpha=0.8, gamma=0.5)(outputs, gts)
            loss = dice_loss
            # loss = focal_loss + dice_loss
            epoch_loss += loss.item() * imgs.size(0)  # because reduction = 'mean'
            loss.backward()
            self.optimizer.step()

            # DC += iou(outputs.detach().cpu().squeeze().argmax(dim=1), gts.detach().cpu(), n_classes=14) * imgs.size(0)
            # length += imgs.size(0)

        # DC = DC / length
        # epoch_loss = epoch_loss / length
        # # Print the log info
        # print('Epoch [%d/%d], Loss: %.4f, \n[Training] DC: %.4f' %
        #       (epoch + 1, self.num_epochs, epoch_loss, DC))
        print('EPOCH{}'.format(epoch))

    # =============================== validation ====================#
    # ===============================================================#
    @torch.no_grad()
    def validation(self, epoch):
        self.net.eval()
        acc = 0.  # Accuracy
        SE = 0.   # Sensitivity (Recall)
        SP = 0.   # Specificity
        PC = 0.   # Precision
        F1 = 0.   # F1 Score
        JS = 0.   # Jaccard Similarity
        DC = 0.   # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(self.val_loader):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs)

            # weight = np.array([0., 100., 100., 100., 50., 50., 80., 80., 50., 80., 80., 80.,
            #                    50., 50., 70., 70., 70., 70., 60., 60., 100., 100., 100.])  # unused
            # per-class IoU over the flattened volume; the background class (index 0)
            # is excluded from the mean, and the accumulation is weighted by batch size
            ious = IoU(gts.detach().cpu().squeeze().numpy().reshape(-1),
                       outputs.detach().cpu().squeeze().argmax(
                           dim=0).numpy().reshape(-1),
                       num_classes=14)
            DC += np.array(ious[1:]).mean() * imgs.size(0)
            length += imgs.size(0)

        DC = DC / length
        score = DC
        print('[Validation] DC: %.4f' % DC)

        # save the best model
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            print('Best %s model score: %.4f' %
                  (self.model_type, self.best_score))
            torch.save(self.net.state_dict(), self.net_path)
        # if (1 + epoch) % 10 == 0 or epoch == 0:
        #     torch.save(self.net.state_dict(), self.net_path + 'epoch{}.pkl'.format(epoch))
        # if (epoch + 1) % 50 == 0 and epoch != 1:
        #     torch.save(self.net.state_dict(),
        #                '/mnt/HDD/datasets/competitions/vnet/models_for_cls/400-200-dice-epoch{}.pkl'.format(epoch + 1))

    def test(self):
        # rebuild the network and load the best checkpoint saved during validation
        del self.net
        self.build_model()
        self.net.load_state_dict(torch.load(self.net_path))
        self.net.eval()

        DC = 0.  # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(self.test_loader):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs)

            # weight = np.array([0., 100., 100., 100., 50., 50., 80., 80., 50., 80., 80., 80.,
            #                    50., 50., 70., 70., 70., 70., 60., 60., 100., 100., 100.])  # unused
            ious = IoU(gts.detach().cpu().squeeze().numpy().reshape(-1),
                       outputs.detach().cpu().squeeze().argmax(
                           dim=0).numpy().reshape(-1),
                       num_classes=14)
            DC += np.array(ious[1:]).mean() * imgs.size(0)
            length += imgs.size(0)

        DC = DC / length
        score = DC

        # append the test result to result.csv
        f = open(os.path.join(self.csv_path, 'result.csv'),
                 'a', encoding='utf8', newline='')
        wr = csv.writer(f)
        wr.writerow([
            self.model_type, DC, self.lr, self.best_epoch, self.num_epochs,
            self.num_epochs_decay, self.augmentation_prob, self.batch_size,
            self.comment
        ])
        f.close()

    def train_val_test(self):
        ################# BUG: resuming from self.net_path here skipped training
        # if os.path.isfile(self.net_path):
        #     # self.net.load_state_dict(torch.load(self.net_path))
        #     print('saved {} is loaded from: {}'.format(self.model_type, self.net_path))
        # else:
        for epoch in range(self.num_epochs):
            self.train(epoch)
            self.validation(epoch)
        self.test()
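# --- hypothetical wiring example (not part of the original script) ------------
# Sketch of how Solver could be driven end-to-end. The args fields match what
# __init__ reads above; the dataset paths and hyper-parameter values are
# placeholders rather than the project's real settings.
def _example_run_solver():
    from types import SimpleNamespace
    args = SimpleNamespace(
        lr=1e-4, beta1=0.9, beta2=0.999,
        num_epochs=400, num_epochs_decay=200, batch_size=1,
        augmentation_prob=0.5, log_step=10, val_step=1,
        model_path='./models', result_path='./results',
        model_type='VNet', comment='dice')

    train_loader = DataLoader(ProbSet('/path/to/train'),                     # placeholder path
                              batch_size=args.batch_size)
    val_loader = DataLoader(ProbSet('/path/to/val', is_train=False),         # placeholder path
                            batch_size=args.batch_size)
    test_loader = DataLoader(ProbSet('/path/to/test', is_train=False),       # placeholder path
                             batch_size=args.batch_size)

    solver = Solver(args, train_loader, val_loader, test_loader)
    solver.train_val_test()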
def predict(config):
    if config.model_type not in ['VNet']:
        print('ERROR!! model_type should be selected in VNet/')
        print('Your input for model_type was %s' % config.model_type)
        return

    # train_set = ProbSet(config.train_path)
    valid_set = ProbSet(config.valid_path, is_train=False)
    test_set = ProbSet(config.test_path, is_train=False, fold=5)
    # print(len(valid_set), len(test_set))
    # train_loader = DataLoader(train_set, batch_size=config.batch_size)
    valid_loader = DataLoader(valid_set, batch_size=config.batch_size)
    test_loader = DataLoader(test_set, batch_size=config.batch_size)

    net = VNet()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    print(config.model_type, net)
    net.load_state_dict(torch.load(config.net_path))
    net.eval()

    DC = 0.  # Dice Coefficient
    length = 0
    iou = 0
    for i, (imgs, gts) in enumerate(test_loader):
        # path = path[0]  # the loader wraps each field in a batch tuple; batch_size is 1
        imgs = imgs.to(device)
        gts = gts.round().long().to(device)
        outputs = net(imgs)
        print(gts.cpu().shape, imgs.shape, outputs.shape)
        # torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 1, 128, 128, 128]) torch.Size([1, 14, 128, 128, 128])

        ious = IoU(
            gts.detach().cpu().squeeze().numpy().reshape(-1),
            outputs.detach().cpu().squeeze().argmax(dim=0).numpy().reshape(-1),
            num_classes=14)
        print(ious)
        print(np.array(ious).mean())
        iou += np.array(ious).mean()
        length += 1
        # output_id = path.split('/')[-1]
        # np.save('/mnt/HDD/datasets/competitions/vnet/output/fold1/output{}.npy'.format(output_id),
        #         outputs.detach().cpu().squeeze().numpy())

        # visualize axial slices 70-127: input, ground truth, and argmax prediction
        for j in range(70, 128):
            plt.figure()
            plt.subplot(2, 2, 1)
            # plt.imshow(np.array(imgs.cpu().squeeze()[j, 0]))
            plt.imshow(np.array(imgs.cpu().squeeze()[j]))
            plt.colorbar()
            plt.subplot(2, 2, 2)
            plt.title(str(np.unique(gts.cpu().detach().numpy().squeeze()[j])))
            plt.imshow(np.array(gts.cpu().detach().numpy().squeeze()[j]))
            plt.colorbar()
            plt.subplot(2, 2, 3)
            plt.title(str(np.unique(outputs.cpu().detach().numpy().squeeze().argmax(axis=0)[j])))
            plt.imshow(outputs.cpu().detach().numpy().squeeze().argmax(axis=0)[j].reshape(128, 128))
            # plt.imshow(outputs.cpu().detach().numpy().squeeze()[8, j].reshape(128, 128))
            plt.colorbar()
            plt.show()
            time.sleep(2)

    # mean IoU over the test set (rather than the previously hard-coded divisor of 10)
    print('######', iou / length)
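# --- reference sketch (not part of the original script) -----------------------
# Both predict() functions and Solver call an IoU(gt_flat, pred_flat, num_classes)
# helper that is not defined in this section. The function below is only a sketch
# of such a helper, assuming it returns a list of per-class IoU values computed
# over the flattened label volumes; classes absent from both arrays are scored 0
# here, which the project's real helper may handle differently.
def iou_per_class(gt_flat, pred_flat, num_classes=14):
    ious = []
    for c in range(num_classes):
        gt_c = (gt_flat == c)
        pred_c = (pred_flat == c)
        union = np.logical_or(gt_c, pred_c).sum()
        intersection = np.logical_and(gt_c, pred_c).sum()
        ious.append(float(intersection) / union if union > 0 else 0.0)
    return ious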