def Train(train_root, train_csv, test_root, test_csv):
    """Train a 2-class (background/kidney) 2D segmentation network.

    Each epoch performs three passes: a training pass, an eval-mode pass over
    the test set, and an eval-mode pass over the training set.  A checkpoint
    is written whenever the eval-mode *training-set* Dice improves.

    Args:
        train_root, train_csv: root directory and CSV listing of training slices.
        test_root, test_csv:   root directory and CSV listing of test slices.

    Side effects:
        Sets CUDA_VISIBLE_DEVICES, seeds all RNGs, emits log records, and
        writes a checkpoint file under ``args.checkpoint``.
    """
    args = parse_args()
    # Best eval-mode training-set Dice seen so far; checkpoint criterion.
    besttraindice = 0.0

    record_params(args)

    # ---- reproducibility / device setup ------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    cudnn.benchmark = args.cudnn != 0
    cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2  # background / kidney
    net = build_model(args.model_name, num_classes)

    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # ---- data ---------------------------------------------------------------
    img_size = args.img_size
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    # Deterministic pipeline only, so test shares the exact same transform.
    test_aug = train_aug

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(root=test_root,
                              csv_file=test_csv,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # ---- loss function, optimizer and scheduler ----------------------------
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # Fail fast: previously this branch only printed a warning and the
        # code then crashed later with a NameError on `criterion`.
        raise ValueError("Unsupported loss '{}'".format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    # NOTE(review): an unknown non-'None' lr_policy leaves `scheduler`
    # undefined and fails at scheduler.step() — assumed caller passes a
    # valid policy; confirm against parse_args defaults.
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # ---- training process ---------------------------------------------------
    logging.info('Start Training For Kidney Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # Training pass.
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # Per-sample weighting so the epoch average is per-image.
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()
        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # Test pass (eval mode, no gradients).
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_count += inputs.shape[0]
                test_loss += loss.item() * inputs.shape[0]
                test_dice += Dice_fn(outputs, targets).item()
        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # Eval-mode pass over the training data (checkpoint criterion).
        # Net is still in eval() mode from the test pass above.
        traineval_loss = 0.
        traineval_dice = 0.
        traineval_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                traineval_count += inputs.shape[0]
                traineval_loss += loss.item() * inputs.shape[0]
                traineval_dice += Dice_fn(outputs, targets).item()
        traineval_loss_epoch = traineval_loss / float(traineval_count)
        traineval_dice_epoch = traineval_dice / float(traineval_count)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f '
            '| traineval_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, traineval_dice_epoch,
               time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        # Checkpoint on best eval-mode training Dice.
        if traineval_dice_epoch > besttraindice:
            besttraindice = traineval_dice_epoch
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net
            if torch.cuda.device_count() > 1:
                # Unwrap DataParallel via its documented `.module` attribute
                # (was `list(net.children())[0]`, which happens to be the same
                # object but is fragile).
                save_model = net.module
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
def Train(train_root, train_csv, test_root, test_csv, traincase_csv,
          testcase_csv):
    """Train a 2-class prostate segmentation network with case-level 3D Dice.

    Per epoch: a training pass, a slice-wise eval pass on the test set, then a
    case-wise (3D volume) Dice evaluation on both test and train cases.  A
    checkpoint is written whenever the train-set *case* Dice improves.

    NOTE(review): this function shadows the earlier 4-argument ``Train`` at
    module level — only this definition survives import; verify intent.

    Args:
        train_root, train_csv: root directory and CSV listing of training slices.
        test_root, test_csv:   root directory and CSV listing of test slices.
        traincase_csv, testcase_csv: CSVs with an 'Image' column naming the
            3D cases used for volume-level evaluation.

    Side effects:
        Sets CUDA_VISIBLE_DEVICES, seeds all RNGs, emits log records, and
        writes a checkpoint file under ``args.checkpoint``.
    """
    args = parse_args()
    # Best train-set case-level Dice seen so far; checkpoint criterion.
    besttraincasedice = 0.0
    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()

    record_params(args)

    # ---- reproducibility / device setup ------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    cudnn.benchmark = args.cudnn != 0
    cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2  # background / prostate
    net = build_model(args.model_name, num_classes)

    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # ---- data ---------------------------------------------------------------
    img_size = args.img_size
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    # Deterministic pipeline only, so test shares the exact same transform.
    test_aug = train_aug

    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=test_root,
                                csv_file=test_csv,
                                transform=test_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # ---- loss function, optimizer and scheduler ----------------------------
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # Fail fast: previously this branch only printed a warning and the
        # code then crashed later with a NameError on `criterion`.
        raise ValueError("Unsupported loss '{}'".format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    def _case_dices(dataset, cases, strip_extension):
        """Per-case 3D Dice of the network's predictions over `dataset`.

        Slices are fetched by running index, so this assumes the dataset's
        slice order groups each case contiguously and matches the sorted
        per-case filename order (same assumption as the original inline code
        — TODO confirm against the dataset implementation).

        Args:
            dataset: a prostate_seg dataset exposing `.imgs` filenames.
            cases: case names from the case CSV.
            strip_extension: drop the filename extension from the case name
                before substring-matching slice filenames.

        Returns:
            torch.Tensor of shape (len(cases),) with one Dice per case.
        """
        dices = torch.zeros(len(cases))
        slice_offset = 0  # cumulative slice count of all previous cases
        for caseno, casename in enumerate(cases):
            caseid = casename.split('.')[0] if strip_extension else casename
            caseimgs = sorted(f for f in dataset.imgs if caseid in f)
            preds = []
            masks = []
            for sliceno in range(len(caseimgs)):
                sample = dataset[slice_offset + sliceno]
                masks.append(sample[2])
                with torch.no_grad():
                    image = torch.unsqueeze(sample[0].to(device), 0)
                    logits = net(image)
                    probs = F.softmax(logits, dim=1)
                    pred = torch.argmax(probs, dim=1)
                    preds.append(pred.squeeze().cpu().numpy())
            # Stack slices into (H, W, depth) volumes.
            volume_mask = np.stack(masks, axis=-1)
            volume_pred = np.stack(preds, axis=-1)
            volume_pred = keep_largest_connected_components(volume_pred)
            dices[caseno] = Dice3d_fn(volume_pred, volume_mask)
            slice_offset += len(caseimgs)
        return dices

    # ---- training process ---------------------------------------------------
    logging.info('Start Training For Prostate Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # Training pass.
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # Per-sample weighting so the epoch average is per-image.
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()
        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # Test pass (eval mode, no gradients).
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_count += inputs.shape[0]
                test_loss += loss.item() * inputs.shape[0]
                test_dice += Dice_fn(outputs, targets).item()
        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # Case-level (3D) Dice.  Net is still in eval() mode.
        # NOTE(review): test case names are stripped of their extension before
        # matching but train case names are not — kept as in the original;
        # verify both CSVs use the intended naming convention.
        # (The original also accumulated every train prediction volume in an
        # unused `generatedmask` list — dead code, dropped here.)
        testcasedices = _case_dices(test_dataset, test_cases,
                                    strip_extension=True)
        testcasedice = testcasedices.sum() / float(len(test_cases))
        traincasedices = _case_dices(train_dataset, train_cases,
                                     strip_extension=False)
        traincasedice = traincasedices.sum() / float(len(train_cases))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: traincase_dice: %.3f | testcase_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, traincasedice, testcasedice, time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        # Checkpoint on best train-set case-level Dice.
        if traincasedice > besttraincasedice:
            besttraincasedice = traincasedice
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net
            if torch.cuda.device_count() > 1:
                # Unwrap DataParallel via its documented `.module` attribute
                # (was `list(net.children())[0]`, which happens to be the same
                # object but is fragile).
                save_model = net.module
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            # Filename keeps the historical '_besttraindice' suffix even
            # though the criterion here is the case-level Dice.
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)