def Train(train_root, train_csv, test_root, test_csv):
    """Fully-supervised single-model training loop for kidney segmentation.

    Each epoch: trains the model on ``train_csv`` slices, evaluates loss/Dice
    on the test split, re-evaluates on the training split in eval mode, and
    saves a checkpoint whenever the eval-mode training Dice improves.

    Args:
        train_root: root directory of the training images/masks.
        train_csv:  CSV listing the training samples.
        test_root:  root directory of the test images/masks.
        test_csv:   CSV listing the test samples.

    NOTE(review): this source was recovered from a whitespace-mangled file;
    the indentation below is reconstructed from the token stream — diff
    against the original repository before trusting nesting edge cases.
    """
    # parameters
    args = parse_args()
    besttraindice = 0.0  # best eval-mode Dice on the training split so far
    # record
    record_params(args)

    # Reproducibility: pin the visible GPUs and seed every RNG in use.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        # NOTE(review): benchmark=True together with deterministic=True is
        # unusual; nesting of the deterministic flag is reconstructed — confirm
        # against the original file.
        cudnn.benchmark = True
        cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2  # background / kidney
    net = build_model(args.model_name, num_classes)

    # resume
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train3_multidomainl_normalcl
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test — identical deterministic transform (no augmentation on train here)
    test_aug = train_aug
    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(root=test_root,
                              csv_file=test_csv,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # NOTE(review): an unknown loss only prints; criterion stays unbound
        # and the first training batch will raise NameError.
        print('Do not have this loss')
    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train3_multidomainl_normalcl
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        # Loader yields (image, <unused>, mask); middle element ignored.
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),
                     total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # Weight the running loss by batch size so the epoch mean is exact.
            train_loss += loss.item() * inputs.shape[0]
            # assumes Dice_fn returns a Dice total over the batch — TODO confirm
            train_dice += Dice_fn(outputs, targets).item()
        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_count += inputs.shape[0]
                test_loss += loss.item() * inputs.shape[0]
                test_dice += Dice_fn(outputs, targets).item()
        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # Re-evaluate on the training split in eval mode (no grad, BN frozen);
        # this "traineval" Dice is the model-selection criterion below.
        traineval_loss = 0.
        traineval_dice = 0.
        traineval_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                traineval_count += inputs.shape[0]
                traineval_loss += loss.item() * inputs.shape[0]
                traineval_dice += Dice_fn(outputs, targets).item()
        # traineval_loss_epoch is computed but only the Dice is logged/used.
        traineval_loss_epoch = traineval_loss / float(traineval_count)
        traineval_dice_epoch = traineval_dice / float(traineval_count)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f '
            '| traineval_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, traineval_dice_epoch,
               time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        # Checkpoint on improved eval-mode training Dice.
        if traineval_dice_epoch > besttraindice:
            besttraindice = traineval_dice_epoch
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net
            if torch.cuda.device_count() > 1:
                # Unwrap nn.DataParallel so the saved state_dict has no
                # 'module.' prefixes.
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
def Train(train_root, train_csv, test_csv, tempmaskfolder):
    """Two-network co-training loop for kidney segmentation with pseudo-label
    refinement.

    Two models are initialised from the same checkpoint (``args.resumefile``)
    and cross-supervised: each network is trained on the *other* network's
    noisy target (net1 <-> targets2, net2 <-> targets1), with a small-loss
    sample split and an augmentation-consistency correction loss weighted by
    ``rate_schedule``. Periodically the worst-scoring training masks are
    overwritten with network predictions (written as NIfTI into
    ``tempmaskfolder``).

    Args:
        train_root:     data root; also receives the temp-mask folder.
        train_csv:      CSV listing training samples.
        test_csv:       CSV listing test samples (read relative to train_root).
        tempmaskfolder: sub-folder for regenerated pseudo-label masks.

    NOTE(review): source recovered from a whitespace-mangled file; nesting is
    reconstructed from the token stream — verify against the repository.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))
    besttraindice = 0.0
    changepointdice = 0.0  # Dice at the last epoch before the curve ascends
    ascending = False      # flips True once traineval Dice starts improving

    # parameters
    args = parse_args()
    record_params(args)

    # Reproducibility: pin GPUs and seed every RNG in use.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        # NOTE(review): nesting of the deterministic flag is reconstructed.
        cudnn.benchmark = True
        cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2
    net1 = build_model(args.model1_name, num_classes)
    net2 = build_model(args.model2_name, num_classes)

    # resume — both nets start from the same warm-up checkpoint.
    params1_name = '{}_warmup{}_temp{}_r{}_net1.pkl'.format(
        args.model1_name, args.warmup_epoch, args.temperature, args.repetition)
    params2_name = '{}_warmup{}_temp{}_r{}_net2.pkl'.format(
        args.model2_name, args.warmup_epoch, args.temperature, args.repetition)
    checkpoint1_path = os.path.join(args.checkpoint, params1_name)
    checkpoint2_path = os.path.join(args.checkpoint, params2_name)
    initializecheckpoint = torch.load(args.resumefile)['net']
    net1.load_state_dict(initializecheckpoint)
    net2.load_state_dict(initializecheckpoint)
    start_epoch = 0
    end_epoch = args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)

    # data — training uses geometric augmentation, testing is deterministic.
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               tempmaskfolder=tempmaskfolder,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    # NOTE: test split is also rooted at train_root here (by design of the
    # pseudo-label setup, not a copy-paste slip — both CSVs index one tree).
    test_dataset = kidney_seg(root=train_root,
                              csv_file=test_csv,
                              tempmaskfolder=tempmaskfolder,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)
    # tempmaskfolder=tempmaskfolder,
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        # Per-image (unreduced) loss so samples can be ranked below.
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')
    # Unreduced MSE: kept per-pixel so it can be masked by the weight maps.
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)

    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)

    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        # Quadratic ramp of the correction-loss weight over warm-up epochs.
        if args.warmup_epoch == 0:
            rate_schedule[epoch] = 1.0
        else:
            rate_schedule[epoch] = min(
                (float(epoch) / float(args.warmup_epoch))**2, 1.0)

        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader),
                     total=int(len(train_loader.dataset) / args.batch_size)):
            # (inputs, augset, targets, targets1, targets2)
            # Phase 1: eval-mode forward on the augmented copies to build
            # sharpened pseudo-labels (detached — no grads through them).
            net1.eval()
            net2.eval()
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimg = augset['img{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimg).detach())
                augoutput2.append(net2(augimg).detach())
            # Undo the geometric augmentation so predictions align spatially.
            # NOTE(review): in the mangled source the net1 call may have been
            # commented out — symmetry suggests both run; confirm upstream.
            augoutput1 = reverseaugbatch(augset, augoutput1, classno=num_classes)
            augoutput2 = reverseaugbatch(augset, augoutput2, classno=num_classes)
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx], dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx], dim=1)
                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            # Average over augmentations, then temperature-sharpen.
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # Confidence map: 1 - 4*p0*p1 is 1 where the 2-class posterior is
            # certain and 0 where it is 50/50.
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:, 0, :, :] * pseudo_label1[:, 1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:, 0, :, :] * pseudo_label2[:, 1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)

            # Phase 2: supervised co-training step.
            net1.train()
            net2.train()
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)
            # Cross-supervision: net1 learns targets2 and vice versa; the
            # per-image losses rank samples for the small-loss split below.
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            # Each net's "clean" subset (2 smallest-loss samples as judged by
            # the peer) gets full segmentation loss ...
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            # ... while the remainder is down-weighted by (1 - rate) and
            # replaced by the confidence-weighted consistency term.
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :], pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor
            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :], pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            # retain_graph: loss2 shares graph pieces with loss1's backward.
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inputs.shape[0]
            train_loss1 += loss1.item() * inputs.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inputs.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)

        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader),
                     total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                targets1 = targets1.to(device)
                targets2 = targets2.to(device)
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
                test_count += inputs.shape[0]
                test_loss1 += loss1.item() * inputs.shape[0]
                test_dice1 += Dice_fn(outputs1, targets2).item()
                test_loss2 += loss2.item() * inputs.shape[0]
                test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)

        # Per-sample eval on the training set: score each slice against the
        # current (possibly pseudo) masks and keep the generated predictions
        # so the worst ones can be overwritten below.
        traindices1 = torch.zeros(len(train_dataset))
        traindices2 = torch.zeros(len(train_dataset))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_dataset)),
                              total=len(train_dataset)):
            sample = train_dataset.__getitem__(casecount)
            img = sample[0]
            # presumably sample[4]/sample[3] are the net1/net2 reference masks
            # (mirroring the targets2/targets1 pairing above) — TODO confirm.
            mask1 = sample[4]
            mask2 = sample[3]
            with torch.no_grad():
                img = torch.unsqueeze(img.to(device), 0)
                output1 = net1(img)
                output1 = F.softmax(output1, dim=1)
                output2 = net2(img)
                output2 = F.softmax(output2, dim=1)
                output1 = torch.argmax(output1, dim=1)
                output2 = torch.argmax(output2, dim=1)
                output1 = output1.squeeze().cpu()
                generatedoutput1 = output1.unsqueeze(dim=0).numpy()
                output2 = output2.squeeze().cpu()
                generatedoutput2 = output2.unsqueeze(dim=0).numpy()
                traindices1[casecount] = Dice2d(generatedoutput1, mask1.numpy())
                traindices2[casecount] = Dice2d(generatedoutput2, mask2.numpy())
                generatedmask1.append(generatedoutput1)
                generatedmask2.append(generatedoutput2)
        evaltrainavgdice1 = traindices1.sum() / float(len(train_dataset))
        evaltrainavgdice2 = traindices2.sum() / float(len(train_dataset))
        evaltrainavgdicetemp = (evaltrainavgdice1 + evaltrainavgdice2) / 2.0

        maskannotations = {
            '1': train_dataset.mask1,
            '2': train_dataset.mask2,
            '3': train_dataset.mask3
        }
        # update pseudolabel — during warm-up every epoch, afterwards every
        # 10th epoch: overwrite the lowest-Dice masks with net predictions.
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            avgdice = evaltrainavgdicetemp  # NOTE(review): assigned, unused
            selected_samples = int(args.update_percent * len(train_dataset))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traindices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net1.nii.gz')
                save_data = generatedmask1[selectedidx]
                # Skip all-background predictions — an empty mask would erase
                # the sample's supervision entirely.
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modified for net1'.format(
                len(selectedidxs)))
            _, sortidx2 = traindices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net2.nii.gz')
                save_data = generatedmask2[selectedidx]
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modify for net2'.format(len(selectedidxs)))

        # Checkpointing starts only after the traineval Dice curve turns
        # upward past the pre-ascent baseline (changepointdice).
        if epoch > 0 and changepointdice < evaltrainavgdicetemp and ascending == False:
            ascending = True
            besttraindice = changepointdice
        if evaltrainavgdicetemp > besttraindice and ascending:
            besttraindice = evaltrainavgdicetemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net1
            if torch.cuda.device_count() > 1:
                # Unwrap DataParallel so keys carry no 'module.' prefix.
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state,
                '{}_besttraindice.pkl'.format(
                    checkpoint1_path.split('.pkl')[0]))
            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state,
                '{}_besttraindice.pkl'.format(
                    checkpoint2_path.split('.pkl')[0]))
        if not ascending:
            changepointdice = evaltrainavgdicetemp

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | '
            'train_dice1: %.3f | test_dice1: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
               train_dice1_epoch, test_dice1_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | '
            'train_dice2: %.3f | test_dice2: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
               train_dice2_epoch, test_dice2_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: evaltrain_dice1: %.3f | evaltrain_dice2: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, evaltrainavgdice1, evaltrainavgdice2,
               time_cost))
        net1.train()
        net2.train()
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_root, test_csv, traincase_csv,
          testcase_csv):
    """Fully-supervised training loop for prostate segmentation with
    case-level (3D) Dice evaluation.

    Per epoch: slice-wise train/test passes, then a per-case pass that stacks
    the slice predictions of each patient into a volume, keeps the largest
    connected component, and computes a 3D Dice. The checkpoint is saved when
    the train-case 3D Dice improves.

    Args:
        train_root/train_csv: training slices.
        test_root/test_csv:   test slices.
        traincase_csv/testcase_csv: CSVs with an 'Image' column listing
            patient case identifiers used to group slices into volumes.

    NOTE(review): source recovered from a whitespace-mangled file; nesting is
    reconstructed from the token stream — verify against the repository.
    """
    # parameters
    args = parse_args()
    besttraincasedice = 0.0
    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()
    # record
    record_params(args)

    # Reproducibility: pin GPUs and seed every RNG in use.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        # NOTE(review): nesting of the deterministic flag is reconstructed.
        cudnn.benchmark = True
        cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2
    net = build_model(args.model_name, num_classes)
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test — same deterministic transform
    test_aug = train_aug
    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=test_root,
                                csv_file=test_csv,
                                transform=test_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # NOTE(review): unknown loss only prints; criterion stays unbound.
        print('Do not have this loss')
    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Prostate Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),
                     total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()
        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_count += inputs.shape[0]
                test_loss += loss.item() * inputs.shape[0]
                test_dice += Dice_fn(outputs, targets).item()
        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # Case-wise 3D evaluation on the test set: slices are stored
        # contiguously per case, so startimgslices accumulates each case's
        # slice count to index into the flat dataset.
        testcasedices = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            # NOTE(review): test ids are stripped of an extension here while
            # train ids below are used verbatim — presumably the CSVs differ;
            # confirm.
            caseidx = test_cases[casecount].split('.')[0]
            caseimg = [file for file in test_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [file for file in test_dataset.masks if caseidx in file]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output = net(input)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            # Stack slices into (H, W, depth) volumes, keep the largest 3D
            # connected component, then score.
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            testcasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                 target)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseimg)
        testcasedice = testcasedices.sum() / float(len(test_cases))

        # Same case-wise evaluation on the training set.
        traincasedices = torch.zeros(len(train_cases))
        startimgslices = torch.zeros(len(train_cases))
        generatedmask = []  # NOTE(review): collected but not used below
        for casecount in tqdm(range(len(train_cases)),
                              total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseimg = [file for file in train_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [
                file for file in train_dataset.masks if caseidx in file
            ]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output = net(input)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            traincasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                  target)
            generatedmask.append(generatedtarget_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseimg)
        traincasedice = traincasedices.sum() / float(len(train_cases))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: traincase_dice: %.3f | testcase_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, traincasedice, testcasedice, time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        # Checkpoint on improved train-case 3D Dice.
        if traincasedice > besttraincasedice:
            besttraincasedice = traincasedice
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net
            if torch.cuda.device_count() > 1:
                # Unwrap DataParallel so keys carry no 'module.' prefix.
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
model_name = ['convnet', 'classifier', 'pan', 'mask_classifier'] optimizer = { 'convnet': optim.SGD(convnet.parameters(), lr=args.lr, weight_decay=1e-4), 'classifier': optim.SGD(classifier.parameters(), lr=args.lr, weight_decay=1e-4), 'pan': optim.SGD(pan.parameters(), lr=args.lr, weight_decay=1e-4), 'mask_classifier': optim.SGD(mask_classifier.parameters(), lr=args.lr, weight_decay=1e-4) } optimizer_lr_scheduler = { 'convnet': PolyLR(optimizer['convnet'], max_iter=args.epochs, power=0.9), 'classifier': PolyLR(optimizer['classifier'], max_iter=args.epochs, power=0.9), 'pan': PolyLR(optimizer['pan'], max_iter=args.epochs, power=0.9), 'mask_classifier': PolyLR(optimizer['mask_classifier'], max_iter=args.epochs, power=0.9) } best_acc = 0 for epoch in range(args.epochs): for m in model_name: optimizer_lr_scheduler[m].step(epoch) logging.info('Epoch:{:}'.format(epoch)) train(epoch, optimizer, training_loader) if epoch % 1 == 0:
#classifier = Classifier(in_features=2048, num_class=NUM_CLASS) pan = PAN(convnet.blocks[::-1]) mask_classifier = Mask_Classifier(in_features=256, num_class=(NUM_CLASS+1)) convnet.to(device) #classifier.to(device) pan.to(device) mask_classifier.to(device) #model_name = ['convnet', 'classifier', 'pan', 'mask_classifier'] model_name = ['convnet', 'pan', 'mask_classifier'] optimizer = {'convnet': optim.SGD(convnet.parameters(), lr=args.lr, weight_decay=1e-4), #'classifier': optim.SGD(classifier.parameters(), lr=args.lr, weight_decay=1e-4), 'pan': optim.SGD(pan.parameters(), lr=args.lr, weight_decay=1e-4), 'mask_classifier': optim.SGD(mask_classifier.parameters(), lr=args.lr, weight_decay=1e-4)} optimizer_lr_scheduler = {'convnet': PolyLR(optimizer['convnet'], max_iter=args.epochs, power=0.9), #'classifier': PolyLR(optimizer['classifier'], max_iter=args.epochs, power=0.9), 'pan': PolyLR(optimizer['pan'], max_iter=args.epochs, power=0.9), 'mask_classifier': PolyLR(optimizer['mask_classifier'], max_iter=args.epochs, power=0.9)} best_acc = 0 for epoch in range(args.epochs): for m in model_name: optimizer_lr_scheduler[m].step(epoch) logging.info('Epoch:{:}'.format(epoch)) train(epoch, optimizer, training_loader) if epoch % 5 == 0: test(test_loader)
def main():
    """Entry point dispatching on ``args.mode``.

    Modes:
        'baseline_train' / 'finetune': classification training with
            CrossEntropyLoss, SGD, and top-1/top-5 tracking.
        'pretrain': DeepLabV3-ResNet50 colorization-style pretraining with
            MSELoss, grouped LRs, and a PolyLR schedule driven by iterations
            (``cur_itrs`` / ``args.total_itrs``).

    Steps: build model -> criterion/optimizer -> metric buffers -> optional
    resume -> train loop, checkpointing each epoch.

    NOTE(review): this source was recovered from a whitespace-mangled file;
    indentation is reconstructed from the token stream.
    """
    global args, best_prec1
    global cur_itrs
    args = parser.parse_args()
    print(args.mode)

    # STEP1: model
    if args.mode == 'baseline_train':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=200)
    elif args.mode == 'pretrain':
        model = deeplab_network.deeplabv3_resnet50(
            num_classes=args.num_classes,
            output_stride=args.output_stride,
            pretrained_backbone=False)
        set_bn_momentum(model.backbone, momentum=0.01)
    elif args.mode == 'finetune':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=3)
        # load the pretrained model (weights only).
        # FIX: the original also called optimizer.load_state_dict() here,
        # which raised UnboundLocalError because `optimizer` is only created
        # in STEP2 below; full optimizer/scheduler resume belongs to STEP4.
        if args.pretrained_model:
            if os.path.isfile(args.pretrained_model):
                print("=> loading pretrained model '{}'".format(
                    args.pretrained_model))
                checkpoint = torch.load(args.pretrained_model)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                # FIX: report the file that was actually loaded
                # (original printed args.resume by mistake).
                print("=> loaded pretrained model '{}' (epoch {})".format(
                    args.pretrained_model, checkpoint['epoch']))

    # FIX: torch.cuda.is_available is a function — the original tested the
    # function object itself (always truthy), crashing on CPU-only machines.
    if torch.cuda.is_available():
        model = model.cuda()

    # STEP2: criterion and optimizer
    if args.mode in ['baseline_train', 'finetune']:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # train_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    elif args.mode == 'pretrain':
        criterion = nn.MSELoss()
        # Backbone gets a 10x smaller LR than the freshly-initialised head.
        optimizer = torch.optim.SGD(params=[
            {'params': model.backbone.parameters(), 'lr': 0.1 * args.lr},
            {'params': model.classifier.parameters(), 'lr': args.lr},
        ], lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = PolyLR(optimizer, args.total_itrs, power=0.9)

    # STEP3: loss/prec record
    if args.mode in ['baseline_train', 'finetune']:
        train_losses = []
        train_top1s = []
        train_top5s = []
        test_losses = []
        test_top1s = []
        test_top5s = []
    elif args.mode == 'pretrain':
        train_losses = []
        test_losses = []

    # STEP4: optionally resume from a checkpoint (.pth plus the sibling .npz
    # carrying the metric history).
    if args.resume:
        print('resume')
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.mode in ['baseline_train', 'finetune']:
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                train_top1s = list(load_data['train_top1s'])
                train_top5s = list(load_data['train_top5s'])
                test_losses = list(load_data['test_losses'])
                test_top1s = list(load_data['test_top1s'])
                test_top5s = list(load_data['test_top5s'])
            elif args.mode == 'pretrain':
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                cur_itrs = checkpoint['cur_itrs']
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                # test_losses = list(load_data['test_losses'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # STEP5: train!
    if args.mode in ['baseline_train', 'finetune']:
        # data
        from utils import TinyImageNet_data_loader
        # NOTE(review): bare `color_distortion` here relies on a module-level
        # name; the loader call uses args.color_distortion — confirm intent.
        print('color_distortion:', color_distortion)
        train_loader, val_loader = TinyImageNet_data_loader(
            args.dataset, args.batch_size,
            color_distortion=args.color_distortion)

        # if evaluate the model
        if args.evaluate:
            print('evaluate this model on validation dataset')
            validate(val_loader, model, criterion, args.print_freq)
            return

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time()  # timekeeping

            # train for one epoch
            model.train()
            loss, top1, top5 = train(train_loader, model, criterion,
                                     optimizer, epoch, args.print_freq)
            train_losses.append(loss)
            train_top1s.append(top1)
            train_top5s.append(top5)

            # evaluate on validation set
            model.eval()
            loss, prec1, prec5 = validate(val_loader, model, criterion,
                                          args.print_freq)
            test_losses.append(loss)
            test_top1s.append(prec1)
            test_top5s.append(prec5)

            # remember the best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.mode + '_' + args.dataset + '.pth')
            np.savez(args.mode + '_' + args.dataset + '.npz',
                     train_losses=train_losses,
                     train_top1s=train_top1s,
                     train_top5s=train_top5s,
                     test_losses=test_losses,
                     test_top1s=test_top1s,
                     test_top5s=test_top5s)
            # np.savez(args.mode + '_' + args.dataset +'.npz', train_losses=train_losses)

            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:',
                  (time2 - time1) * (args.epochs - epoch - 1) / 60, 'minutes')
            print()
    elif args.mode == 'pretrain':
        # data
        from utils import TinyImageNet_data_loader
        # args.dataset = 'tiny-imagenet-200'
        args.batch_size = 16  # colorization pretraining uses a fixed batch size
        train_loader, val_loader = TinyImageNet_data_loader(
            args.dataset, args.batch_size, col=True)

        # if evaluate the model, show some results
        if args.evaluate:
            print('evaluate this model on validation dataset')
            visulization(val_loader, model, args.start_epoch)
            return

        # for epoch in range(args.start_epoch, args.epochs):
        epoch = 0
        # Iteration-driven loop: runs until cur_itrs reaches total_itrs.
        # NOTE(review): the train/validate calls below are commented out in
        # the original, so cur_itrs never advances and this loop only saves
        # checkpoints forever — re-enable the train() call to pretrain.
        while True:
            if cur_itrs >= args.total_itrs:
                return
            # adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time()  # timekeeping
            model.train()
            # train for one epoch
            # loss, _, _ = train(train_loader, model, criterion, optimizer, epoch, args.print_freq, colorization=True,scheduler=scheduler)
            # train_losses.append(loss)
            # model.eval()
            # # evaluate on validation set
            # loss, _, _ = validate(val_loader, model, criterion, args.print_freq, colorization=True)
            # test_losses.append(loss)
            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                "cur_itrs": cur_itrs
            }, True, args.mode + '_' + args.dataset + '.pth')
            np.savez(args.mode + '_' + args.dataset + '.npz',
                     train_losses=train_losses)
            # scheduler.step()

            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:',
                  (time2 - time1) * (args.total_itrs - cur_itrs - 1) / 60,
                  'minutes')
            print()
            epoch += 1
def Train(train_root, train_csv, test_csv, traincase_csv, testcase_csv,
          labelcase_csv, tempmaskfolder):
    """Co-train two segmentation networks on CHAOS MR data with pseudo-label refresh.

    net1 and net2 are trained jointly: each network's low-loss samples supervise
    the other ("co-teaching"), while high-loss samples are pulled toward the
    peer's augmentation-averaged, sharpened pseudo-labels. Periodically, the
    worst-scoring unlabeled cases have their on-disk temp masks rewritten with
    the networks' current predictions.

    Args:
        train_root: dataset root; temp pseudo-masks are written under
            train_root/tempmaskfolder.
        train_csv, test_csv: per-slice sample lists for the two splits
            (both datasets are rooted at train_root).
        traincase_csv, testcase_csv, labelcase_csv: csv files whose
            'patient_case' column lists train / test / ground-truth-labeled cases.
        tempmaskfolder: subfolder name for generated pseudo-label masks.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))
    # parameters
    args = parse_args()
    # record
    record_params(args)
    train_cases = pd.read_csv(traincase_csv)['patient_case'].tolist()
    test_cases = pd.read_csv(testcase_csv)['patient_case'].tolist()
    label_cases = pd.read_csv(labelcase_csv)['patient_case'].tolist()
    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2
    # two independently initialized peers of the same architecture
    net1 = build_model(args.model_name, num_classes)
    net2 = build_model(args.model_name, num_classes)
    params1_name = '{}_temp{}_r{}_net1.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    params2_name = '{}_temp{}_r{}_net2.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    start_epoch = 0
    end_epoch = args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)
    # data: geometric augmentation for training, deterministic resize for test
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    train_dataset = chaos_seg(root=train_root, csv_file=train_csv,
                              tempmaskfolder=tempmaskfolder,
                              transform=train_aug)
    # NOTE(review): test split is also rooted at train_root -- confirm both
    # splits live under the same directory tree.
    test_dataset = chaos_seg(root=train_root, csv_file=test_csv,
                             tempmaskfolder=tempmaskfolder,
                             transform=test_aug)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=4, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=4, shuffle=False)
    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        # per-image losses are needed below for loss-based sample sorting
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')
    # per-pixel correction loss (no reduction: weighted by a confidence map)
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)
    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)
    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)
    # training process
    logging.info('Start Training For CHAOS Seg')
    besttraincasedice = 0.0
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        # quadratic warm-up of the correction-loss weight, capped at 1.0
        rate_schedule[epoch] = min(
            (float(epoch) / float(args.warmup_epoch))**2, 1.0)
        # train
        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader),
                     total=int(len(train_loader.dataset) / args.batch_size)):
            # --- build pseudo-labels from the augmented views (no gradients) ---
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimgin = augset['imgmodal1{}'.format(aug_idx + 1)].to(device)
                augimgout = augset['imgmodal2{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimgin, augimgout).detach())
                augoutput2.append(net2(augimgin, augimgout).detach())
            # undo the geometric augmentations so predictions align spatially
            augoutput1 = reverseaug(augset, augoutput1, classno=num_classes)
            augoutput2 = reverseaug(augset, augoutput2, classno=num_classes)
            # average softmax maps over all augmented views
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx], dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx], dim=1)
                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            # temperature sharpening reduces pseudo-label entropy
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # confidence map: 1 at confident pixels, 0 where p0*p1 = 0.25
            # (maximally uncertain for the 2-class case)
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:, 0, :, :] * pseudo_label1[:, 1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:, 0, :, :] * pseudo_label2[:, 1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)
            # --- supervised forward/backward pass ---
            inphase = inphase.to(device)
            outphase = outphase.to(device)
            # keep only the foreground channel of the one-hot targets
            targets1 = targets1[:, 1, :, :].to(device)
            targets2 = targets2[:, 1, :, :].to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            outputs1 = net1(inphase, outphase)
            outputs2 = net2(inphase, outphase)
            # per-image losses; cross-supervision: net1 learns from targets2
            # (net2's masks) and vice versa
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            # NOTE(review): the 2/rest split hard-codes "2 low-loss samples" --
            # assumes a fixed batch size (looks like 4); TODO confirm.
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            # correction loss: pull high-loss samples toward the peer's
            # pseudo-labels, weighted by the peer's confidence map
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :], pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            # blend: as rate_schedule ramps 0->1, trust shifts from the noisy
            # masks (seg2 term) to the correction term
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor
            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :], pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            # retain_graph: loss2 backprops through graph state shared with loss1
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inphase.shape[0]
            train_loss1 += loss1.item() * inphase.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inphase.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)
        # debug prints: loss-term balance of the LAST batch of the epoch
        print(rate_schedule[epoch])
        print(args.segcor_weight[0] *
              (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor)
        print(args.segcor_weight[0] *
              (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor)
        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader),
                     total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inphase = inphase.to(device)
                outphase = outphase.to(device)
                targets1 = targets1[:, 1, :, :].to(device)
                targets2 = targets2[:, 1, :, :].to(device)
                outputs1 = net1(inphase, outphase)
                outputs2 = net2(inphase, outphase)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
                test_count += inphase.shape[0]
                test_loss1 += loss1.item() * inphase.shape[0]
                test_dice1 += Dice_fn(outputs1, targets2).item()
                test_loss2 += loss2.item() * inphase.shape[0]
                test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)
        # --- case-level (3D) dice on the test split ---
        testcasedices1 = torch.zeros(len(test_cases))
        testcasedices2 = torch.zeros(len(test_cases))
        # startimgslices[i] holds the slice count of case i-1 so the cumulative
        # sum gives each case's starting index into the flat dataset
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount]
            # gather this case's slices; paths look like "<case>/<file>"
            caseinphaseimg = [
                file for file in test_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in test_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            casemask = [
                file for file in test_dataset.masks
                if int(file.split('/')[0]) == caseidx
            ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                # sanity-check that in-phase, out-phase and mask files line up
                # (out-phase slice numbers are offset by one from in-phase)
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                    casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                    caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                    int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            # stack slices into (H, W, depth) volumes
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            # post-process: drop all but the largest connected component
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            testcasedices1[casecount] = Dice3d_fn(generatedtarget1_keeplargest,
                                                  target1)
            testcasedices2[casecount] = Dice3d_fn(generatedtarget2_keeplargest,
                                                  target2)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)
        testcasedice1 = testcasedices1.sum() / float(len(test_cases))
        testcasedice2 = testcasedices2.sum() / float(len(test_cases))
        # --- case-level dice on the train split + collect prediction volumes
        #     for the pseudo-label update below ---
        traincasedices1 = torch.zeros(len(train_cases))
        traincasedices2 = torch.zeros(len(train_cases))
        # update pseudolabel
        startimgslices = torch.zeros(len(train_cases))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseinphaseimg = [
                file for file in train_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in train_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            if caseidx in label_cases:
                # ground-truth masks: "<case>/<file>"
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[0].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[0]) == caseidx
                ]
            else:
                # generated temp masks: case id is the second-to-last component
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[-2].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[-2]) == caseidx
                ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                    casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                    caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                    int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            traincasedices1[casecount] = Dice3d_fn(
                generatedtarget1_keeplargest, target1)
            traincasedices2[casecount] = Dice3d_fn(
                generatedtarget2_keeplargest, target2)
            generatedmask1.append(generatedtarget1_keeplargest)
            generatedmask2.append(generatedtarget2_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)
        traincasedice1 = traincasedices1.sum() / float(len(train_cases))
        traincasedice2 = traincasedices2.sum() / float(len(train_cases))
        traincasediceavgtemp = (traincasedice1 + traincasedice2) / 2.0
        # checkpoint both nets when the averaged train-case dice improves
        if traincasediceavgtemp > besttraincasedice:
            besttraincasedice = traincasediceavgtemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net1
            if torch.cuda.device_count() > 1:
                save_model = list(net1.children())[0]  # unwrap DataParallel
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params1_name.split('.pkl')[0] + '_besttraincasedice.' +
                params1_name.split('.')[-1])
            torch.save(state, savecheckname)
            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            # NOTE(review): '_besttraincasedicde' looks like a typo of
            # '_besttraincasedice'; kept byte-identical here because fixing it
            # changes the saved checkpoint filename.
            savecheckname = os.path.join(
                args.checkpoint,
                params2_name.split('.pkl')[0] + '_besttraincasedicde.' +
                params2_name.split('.')[-1])
            torch.save(state, savecheckname)
        # --- pseudo-label refresh: rewrite temp masks of the worst 25% of
        #     cases (unlabeled only) during warm-up, then every 10 epochs ---
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            selected_samples = int(0.25 * len(train_cases))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traincasedices1.sort()
            selectedidxs = sortidx1[:selected_samples]  # lowest-dice cases
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net1.png')
                        save_data = generatedmask1[selectedidx][:, :, imgidx]
                        # scale label {0,1} by 63 to match the dataset's
                        # grayscale mask encoding
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net1'.format(
                [train_cases[i] for i in selectedidxs]))
            _, sortidx2 = traincasedices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net2.png')
                        save_data = generatedmask2[selectedidx][:, :, imgidx]
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net2'.format(
                [train_cases[i] for i in selectedidxs]))
        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | train_dice1: %.3f | test_dice1: %.3f || '
            'traincase_dice1: %.3f || testcase_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, traincasedice1,
             testcasedice1, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | train_dice2: %.3f | test_dice2: %.3f || '
            'traincase_dice2: %.3f || testcase_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, traincasedice2,
             testcasedice2, time_cost))
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_csv):
    """Train a 4-class breast-density classifier and checkpoint on best metrics.

    Args:
        train_root: image root directory; used for BOTH the train and test
            datasets (assumes both splits live under the same root -- TODO confirm).
        train_csv: csv listing training samples.
        test_csv: csv listing test samples.

    Side effects: reads CLI options via parse_args()/record_params(), writes
    two checkpoint files under args.checkpoint (best test loss, best test
    accuracy), and logs per-epoch metrics.
    """
    # parameters
    args = parse_args()
    record_params(args)

    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 4
    net = build_model(args.model_name, num_classes, args.pretrain)

    # resume: separate checkpoints tracking best test loss and best test acc
    checkpoint_name_loss = os.path.join(
        args.checkpoint,
        args.params_name.split('.')[0] + '_loss.' + args.params_name.split('.')[-1])
    checkpoint_name_acc = os.path.join(
        args.checkpoint,
        args.params_name.split('.')[0] + '_acc.' + args.params_name.split('.')[-1])
    if args.resume != 0:
        logging.info('Resuming from checkpoint...')
        checkpoint = torch.load(checkpoint_name_loss)
        best_loss = checkpoint['loss']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        history = checkpoint['history']
        net.load_state_dict(checkpoint['net'])
    else:
        best_loss = float('inf')
        best_acc = 0.0
        start_epoch = 0
        history = {
            'train_loss': [],
            'train_acc': [],
            'test_loss': [],
            'test_acc': []
        }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train: geometric augmentation + normalization
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        RandomHorizontallyFlip(),
        RandomVerticallyFlip(),
        RandomRotate(90),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    ## test: deterministic resize + normalization only
    test_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    train_dataset = breast_classify_inbreast(root=train_root,
                                             csv_file=train_csv,
                                             transform=train_aug)
    # NOTE(review): the test dataset is also rooted at train_root -- verify
    # both splits really share this directory.
    test_dataset = breast_classify_inbreast(root=train_root,
                                            csv_file=test_csv,
                                            transform=test_aug)

    # class weights compensate the imbalanced density classes; the loader is
    # identical either way, so build it once (original duplicated this code)
    if args.weighted_sampling == 1:
        weights = torch.FloatTensor([1.0, 1.0, 1.5, 5.0]).to(device)
    else:
        weights = None
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=4, shuffle=True)
    # fix: evaluation needs no shuffling; deterministic order aids debugging
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=4, shuffle=False)

    # loss function, optimizer and scheduler
    # fix: size_average is deprecated; reduction='mean' is the equivalent
    criterion = nn.NLLLoss(reduction='mean', weight=weights).to(device)
    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Breast Density Classification')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_acc = 0.
        for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
                                                 total=int(len(train_loader))):
            inputs = inputs.to(device)
            targets = targets.to(device).long()
            optimizer.zero_grad()
            outputs = net(inputs)
            # the model emits raw logits; NLLLoss expects log-probabilities
            loss = criterion(F.log_softmax(outputs, dim=1), targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += (outputs.argmax(dim=1) == targets).sum().item()
        train_acc_epoch = train_acc / (len(train_loader.dataset))
        train_loss_epoch = train_loss / (batch_idx + 1)
        history['train_loss'].append(train_loss_epoch)
        history['train_acc'].append(train_acc_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_acc = 0.
        for batch_idx, (inputs, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size) + 1):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device).long()
                outputs = net(inputs)
                loss = criterion(F.log_softmax(outputs, dim=1), targets)
                test_acc += (outputs.argmax(dim=1) == targets).sum().item()
                test_loss += loss.item()
        test_loss_epoch = test_loss / (batch_idx + 1)
        test_acc_epoch = test_acc / (len(test_loader.dataset))
        history['test_loss'].append(test_loss_epoch)
        history['test_acc'].append(test_acc_epoch)

        # fix: step the LR scheduler AFTER this epoch's optimizer updates
        # (PyTorch >= 1.1 ordering; the original stepped at the top of the
        # loop, which skipped the initial learning rate entirely)
        if args.lr_policy != 'None':
            scheduler.step()

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | train_acc: %.3f | test_loss: %.3f | test_acc: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, train_acc_epoch,
               test_loss_epoch, test_acc_epoch, time_cost))

        # save checkpoint whenever either tracked metric improves
        if test_loss_epoch < best_loss:
            logging.info('Loss checkpoint Saving...')
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]  # unwrap DataParallel
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'acc': test_acc_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            torch.save(state, checkpoint_name_loss)
            best_loss = test_loss_epoch
        if test_acc_epoch > best_acc:
            logging.info('Acc checkpoint Saving...')
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'acc': test_acc_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            torch.save(state, checkpoint_name_acc)
            best_acc = test_acc_epoch