def main():
    args = parse_args()
    args.pretrain = False

    root_path = 'exps/exp_{}'.format(args.exp)
    if not os.path.exists(root_path):
        os.mkdir(root_path)
        os.mkdir(os.path.join(root_path, "log"))
        os.mkdir(os.path.join(root_path, "model"))

    base_lr = args.lr  # base learning rate

    train_dataset, val_dataset = build_dataset(args.dataset, args.data_root,
                                               args.train_list)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    model = torch.nn.DataParallel(model)
    model.train()

    if args.resume is None:
        # Fine-tuning: initialize from pretrained weights.
        assert os.path.exists(args.load_path)
        state_dict = model.state_dict()
        print("Loading weights...")
        pretrain_state_dict = torch.load(args.load_path,
                                         map_location="cpu")['state_dict']
        # Drop pretrained keys that have no counterpart in the current model.
        for k in list(pretrain_state_dict.keys()):
            if k not in state_dict:
                del pretrain_state_dict[k]
        # strict=False: the filtered dict no longer covers every model key.
        model.load_state_dict(pretrain_state_dict, strict=False)
        print("Loaded weights")
    else:
        print("Resuming from {}".format(args.resume))
        checkpoint = torch.load(args.resume, map_location="cpu")
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(root_path)
    saver = Saver(root_path)

    for epoch in range(args.start_epoch, args.epochs):
        train(model, train_loader, optimizer, logger, args, epoch)
        validate(model, val_loader, optimizer, logger, saver, args, epoch)
        adjust_learning_rate(args, optimizer, epoch)
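# `adjust_learning_rate` is called in the loop above but not defined in this
# file. Below is a minimal sketch, assuming a step-decay schedule; the
# `args.lr_decay_epochs` and `args.lr_decay_rate` attributes are hypothetical
# names, and the repository's actual schedule may differ.
def adjust_learning_rate(args, optimizer, epoch):
    """Step decay: multiply the base LR by lr_decay_rate at each milestone."""
    lr = args.lr
    for milestone in args.lr_decay_epochs:  # e.g. [120, 160] (assumed)
        if epoch >= milestone:
            lr *= args.lr_decay_rate  # e.g. 0.1 (assumed)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr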
class Solver(object):
    def __init__(self, args, train_loader, val_loader, test_loader):
        # data loaders
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # model
        self.net = None
        self.optimizer = None
        self.criterion = FocalLoss(alpha=0.8, gamma=0.5)  # alternative: torch.nn.BCELoss()
        self.augmentation_prob = args.augmentation_prob

        # hyper-parameters
        self.lr = args.lr
        self.decayed_lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # training settings
        self.num_epochs = args.num_epochs
        self.num_epochs_decay = args.num_epochs_decay
        self.batch_size = args.batch_size

        # step sizes for logging and validation
        self.log_step = args.log_step
        self.val_step = args.val_step

        # paths
        self.best_score = 0.549  # score to beat before a new checkpoint is saved
        self.best_epoch = 0
        self.model_path = args.model_path
        self.csv_path = args.result_path
        self.model_type = args.model_type
        self.comment = args.comment
        self.net_path = os.path.join(
            self.model_path, '%s-%d-%.7f-%d-%.4f-%s.pkl' %
            (self.model_type, self.num_epochs, self.lr,
             self.num_epochs_decay, self.augmentation_prob, self.comment))

        # TODO: multi-GPU setting
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.build_model()

    def build_model(self):
        if self.model_type == 'VNet':
            self.net = VNet()
            # TODO: make the checkpoint path configurable.
            self.net.load_state_dict(
                torch.load(
                    '/mnt/HDD/datasets/competitions/vnet/models_for_cls/VNet-400-0.0001000-200-0.5000-ce-400-200-vnet-dice+ce.pkl'
                ))
        self.optimizer = optim.Adam(self.net.parameters(), self.lr,
                                    [self.beta1, self.beta2])
        self.net.to(self.device)
        # self.print_network(self.net, self.model_type)

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()  # numel() returns the total number of elements in a tensor
        print(model)
        print(name)
        print('the number of parameters: {}'.format(num_params))

    # =============================== train =========================#
    # ===============================================================#
    def train(self, epoch):
        self.net.train(True)

        # Decay the learning rate linearly over the final num_epochs_decay epochs.
        if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
            self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.decayed_lr
            print('epoch {}: decayed learning rate to lr: {}.'.format(
                epoch, self.decayed_lr))

        epoch_loss = 0
        acc = 0.  # Accuracy
        SE = 0.   # Sensitivity (Recall)
        SP = 0.   # Specificity
        PC = 0.   # Precision
        F1 = 0.   # F1 Score
        JS = 0.   # Jaccard Similarity
        DC = 0.   # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(tqdm(self.train_loader)):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)

            self.optimizer.zero_grad()
            outputs = self.net(imgs)

            # Generalized Dice loss on one-hot targets (14 classes, 128^3 patches).
            # Alternatives tried: class-weighted CrossEntropyLoss, BCEWithLogitsLoss,
            # FocalLoss(alpha=0.8, gamma=0.5), and focal + dice combinations.
            dice_loss = GeneralizedDiceLoss(sigmoid_normalization=False)(
                outputs, expand_as_one_hot(gts.reshape(-1, 128, 128, 128), 14))
            loss = dice_loss

            epoch_loss += loss.item() * imgs.size(0)  # because reduction = 'mean'
            loss.backward()
            self.optimizer.step()

        # Per-batch training Dice tracking is disabled here; validation reports DC.
        print('EPOCH {}'.format(epoch))

    # =============================== validation ====================#
    # ===============================================================#
    @torch.no_grad()
    def validation(self, epoch):
        self.net.eval()

        acc = 0.  # Accuracy
        SE = 0.   # Sensitivity (Recall)
        SP = 0.   # Specificity
        PC = 0.   # Precision
        F1 = 0.   # F1 Score
        JS = 0.   # Jaccard Similarity
        DC = 0.   # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(self.val_loader):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs)
            ious = IoU(gts.detach().cpu().squeeze().numpy().reshape(-1),
                       outputs.detach().cpu().squeeze().argmax(
                           dim=0).numpy().reshape(-1),
                       num_classes=14) * imgs.size(0)
            DC += np.array(ious[1:]).mean()  # mean over foreground classes only
            length += imgs.size(0)
        DC = DC / length
        score = DC
        print('[Validation] DC: %.4f' % DC)

        # keep the best-scoring checkpoint
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            print('Best %s model score: %.4f' %
                  (self.model_type, self.best_score))
            torch.save(self.net.state_dict(), self.net_path)

    def test(self):
        # Rebuild the model and reload the best checkpoint before testing.
        del self.net
        self.build_model()
        self.net.load_state_dict(torch.load(self.net_path))
        self.net.eval()

        DC = 0.  # Dice Coefficient
        length = 0
        for i, (imgs, gts) in enumerate(self.test_loader):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs)
            ious = IoU(gts.detach().cpu().squeeze().numpy().reshape(-1),
                       outputs.detach().cpu().squeeze().argmax(
                           dim=0).numpy().reshape(-1),
                       num_classes=14) * imgs.size(0)
            DC += np.array(ious[1:]).mean()
            length += imgs.size(0)
        DC = DC / length
        score = DC

        # append the final result to the CSV report
        with open(os.path.join(self.csv_path, 'result.csv'),
                  'a', encoding='utf8', newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type, DC, self.lr, self.best_epoch,
                self.num_epochs, self.num_epochs_decay,
                self.augmentation_prob, self.batch_size, self.comment
            ])

    def train_val_test(self):
        for epoch in range(self.num_epochs):
            self.train(epoch)
            self.validation(epoch)
        self.test()
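# `expand_as_one_hot` (used in Solver.train) turns an integer label volume of
# shape (N, D, H, W) into the one-hot (N, C, D, H, W) tensor that
# GeneralizedDiceLoss expects. A minimal sketch of the idea, not necessarily
# the exact implementation this repository imports:
def expand_as_one_hot_sketch(labels, num_classes):
    """labels: LongTensor (N, D, H, W) -> one-hot FloatTensor (N, C, D, H, W)."""
    labels = labels.unsqueeze(1)  # (N, 1, D, H, W)
    shape = list(labels.shape)
    shape[1] = num_classes
    one_hot = torch.zeros(shape, device=labels.device)
    return one_hot.scatter_(1, labels, 1)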
# net = torch.load(snapshot_path)
dices = []
dice_for_cases = []
case_list = []
sys.stdout.flush()

# read the case list for this cross-validation fold
image_list = open(list_path).readlines()

assert os.path.exists(args.load_path)
state_dict = torch.load(args.load_path, map_location="cpu")['state_dict']
# Strip the 'module.' prefix that DataParallel prepends to parameter names.
new_state_dict = OrderedDict()
for key in state_dict.keys():
    new_state_dict[key[7:]] = state_dict[key]
state_dict = net.state_dict()
print("Loading weights...")
# Keep only the checkpoint keys that exist in the current model.
for k in list(new_state_dict.keys()):
    if k not in state_dict:
        del new_state_dict[k]
state_dict.update(new_state_dict)
net.load_state_dict(state_dict)

net.cuda()
net.eval()

# test passed for the first case
for i in range(0, len(image_list)):
    file_name = image_list[i].strip('\n')
    if '/' in file_name:
        file_name = os.path.basename(file_name)
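# `dices` and `dice_for_cases` above suggest per-class and per-case Dice
# tracking. A minimal helper sketch (not part of the original file), assuming
# `pred` and `gt` are integer-label numpy volumes of the same shape:
import numpy as np  # assumed to be available in this script

def dice_per_class(pred, gt, num_classes=14, eps=1e-5):
    """Return the Dice coefficient for each foreground class 1..num_classes-1."""
    scores = []
    for c in range(1, num_classes):
        p = (pred == c)
        g = (gt == c)
        inter = np.logical_and(p, g).sum()
        scores.append((2.0 * inter + eps) / (p.sum() + g.sum() + eps))
    return scores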
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = VNet(args.n_channels, args.n_classes, input_size=64,
                 pretrain=True).cuda()
    model_ema = VNet(args.n_channels, args.n_classes, input_size=64,
                     pretrain=True).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)
    model = torch.nn.DataParallel(model)
    model_ema = torch.nn.DataParallel(model_ema)
    # The momentum encoder starts as an exact copy of the online model.
    model_ema.load_state_dict(model.state_dict())
    print("Model Initialized")

    logger = Logger(root_path)
    saver = Saver(root_path, save_freq=args.save_freq)

    # 128-d features, queue size K=4096, temperature T.
    if args.sampling == 'default':
        contrast = RGBMoCo(128, K=4096, T=args.temperature).cuda()
    elif args.sampling == 'layerwise':
        contrast = RGBMoCoNew(128, K=4096, T=args.temperature).cuda()
    else:
        raise ValueError("unsupported sampling method")
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(args.start_epoch, args.epochs):
        pretrain(model, model_ema, train_loader, optimizer, logger, saver,
                 args, epoch, contrast, criterion)
        adjust_learning_rate(args, optimizer, epoch)
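# `model_ema` is the momentum encoder of this MoCo-style pretraining setup;
# `pretrain` is expected to refresh it with an exponential moving average of
# the online model's weights. A minimal sketch of that update (the momentum
# value 0.999 is an assumption, not taken from this file):
@torch.no_grad()
def momentum_update(model, model_ema, m=0.999):
    """EMA update: ema_param <- m * ema_param + (1 - m) * param."""
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data.mul_(m).add_(p.data, alpha=1 - m)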