def main(train_args):
    """Train PSPNet on Pascal VOC, optionally resuming from a snapshot.

    train_args keys used: 'snapshot', 'input_size', 'train_batch_size',
    'lr', 'weight_decay', 'momentum'.  Mutates train_args by storing a
    'best_record' dict with the best validation metrics seen so far.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with a sentinel best record.
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0,
                                     'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        # Bug fix: this was a Python 2 print statement (syntax error under
        # Python 3; the rest of this file uses print()).
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot filenames encode metrics as 'epoch_<e>_loss_<l>_acc_<a>_...';
        # the odd indices below pick the values back out of that pattern.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]),
                                     'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]),
                                     'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]),
                                     'fwavacc': float(split_snapshot[11])}
    net.train()

    # ImageNet per-channel normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Undo normalization so tensors can be rendered as PIL images again.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train', simul_transform=train_simul_transform,
                        transform=train_input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = voc.VOC('val', simul_transform=val_simul_transform,
                      transform=val_input_transform, target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=voc.ignore_label).cuda()

    # Bias parameters get twice the base LR and no weight decay, a common
    # convention for segmentation networks.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name,
                                                          'opt_' + train_args['snapshot'])))
        # Re-apply the per-group learning rates: the saved optimizer state
        # may carry stale values from the previous run.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Bug fix: open(...).write(...) leaked the file handle; the context
    # manager guarantees the run-config log is flushed and closed.
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader,
          restore_transform, visualize)
def Train(train_root, train_csv, test_root, test_csv, iter_time, checkpoint_name=None):
    """Train a polyp-segmentation network with superpixel-refined targets.

    Before epoch ``args.size_loss_epoch`` the ground-truth masks are smoothed
    by majority vote inside SLIC superpixels and used as training targets;
    from that epoch on, the network's own predictions are refined the same
    way and trained against with ``criterion2``.

    NOTE(review): several names used below are not defined in this function
    and must come from module scope -- `load_checkpoint`, `loss_type`,
    `writer`, `dice_fn`. Verify they exist there, otherwise this function
    raises NameError at runtime.
    """
    # record: log the wall-clock start time of the run
    localtime = time.asctime(time.localtime(time.time()))
    logging.info('Seg polyp (Data: %s)' % localtime)
    logging.info('\n')

    # parameters: everything below is driven by the argparse namespace
    args = parse_args()
    logging.info('Parameters: ')
    logging.info('model name: %s' % args.model_name)
    logging.info('torch seed: %d' % args.torch_seed)
    logging.info('gpu order: %s' % args.gpu_order)
    logging.info('batch size: %d' % args.batch_size)
    logging.info('num epoch: %d' % args.num_epoch)
    logging.info('ite_start_time: %d' % args.ite_start_time)
    logging.info('ite_end_time: %d' % args.ite_end_time)
    logging.info('learing rate: %f' % args.lr)
    logging.info('loss: %s' % args.loss)
    logging.info('img_size: %s' % str(args.img_size))
    logging.info('lr_policy: %s' % args.lr_policy)
    logging.info('resume: %s' % args.resume)
    # NOTE(review): '%' binds tighter than '+', so the two lines below format
    # fold_num into the template and then append log/params_name -- confirm
    # that is the intended log text.
    logging.info('log_name: %s' % args.fold_num + args.log_name)
    logging.info('params_name: %s' % args.fold_num + args.params_name)
    logging.info('\n')

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Binary segmentation: background vs. polyp.
    num_classes = 2
    net = build_model(args.model_name, num_classes)

    # resume: derive a default checkpoint path when none was supplied
    if checkpoint_name is None:
        checkpoint_path = os.path.join(args.checkpoint, args.model_name)
        if not os.path.exists(checkpoint_path):
            os.mkdir(checkpoint_path)
        checkpoint_name = os.path.join(checkpoint_path, args.fold_num + args.params_name)
    # NOTE(review): `load_checkpoint` is not defined in this function --
    # presumably a module-level flag; confirm, or this is a NameError.
    if load_checkpoint:
        logging.info('Resuming from checkpoint...')
        checkpoint = torch.load(checkpoint_name)
        best_loss = checkpoint['loss']
        best_dice = checkpoint['dice']
        start_epoch = checkpoint['epoch']
        history = checkpoint['history']
        net.load_state_dict(checkpoint['net'])
    else:
        best_loss = float('inf')
        best_dice = 0
        start_epoch = 0
        history = {'train_loss': [], 'test_loss': [], 'test_dice': []}
    # NOTE(review): this unconditionally discards the resumed epoch counter
    # loaded just above -- confirm whether restarting at 0 is intended.
    start_epoch = 0
    end_epoch = start_epoch + args.num_epoch

    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train
    # train_aug = Compose([
    #     Resize(size=(img_size, img_size)),
    #     ToTensor(),
    #     Normalize(mean=(0.5, 0.5, 0.5),
    #               std=(0.5, 0.5, 0.5))])
    # RandomOrder
    if args.style == 'aug':
        train_img_aug = Compose_own([
            # RandomAffine(90, shear=45),
            # RandomRotation(90),
            # RandomHorizontalFlip(),
            # ColorJitter(brightness=0.05),
            Resize(img_size),
            ToTensor()
        ])
        train_mask_aug = Compose_own([
            # RandomAffine(90, shear=45),
            # RandomRotation(90),
            # RandomHorizontalFlip(),
            # ColorJitter(brightness=0.05),
            Resize(img_size),
            ToTensor()
        ])
    else:
        train_img_aug = Compose_own([Resize(img_size), ToTensor()])
        train_mask_aug = Compose_own([Resize(img_size), ToTensor()])
    ## test
    test_img_aug = Compose_own([Resize(size=img_size), ToTensor()])
    test_mask_aug = Compose_own([Resize(size=img_size), ToTensor()])

    train_dataset = poly_seg(root=train_root, csv_file=train_csv,
                             img_transform=train_img_aug,
                             mask_transform=train_mask_aug, iter_time=iter_time)
    test_dataset = poly_seg(root=test_root, csv_file=test_csv,
                            img_transform=test_img_aug,
                            mask_transform=test_mask_aug)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=8, shuffle=False, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=8, shuffle=False, drop_last=False)

    # loss function, optimizer and scheduler
    # NOTE(review): `loss_type` is not defined in this function (elsewhere in
    # this file the same switch reads args.loss) -- confirm it exists at
    # module scope.  Also: for an unrecognized value, `criterion` is never
    # bound, and `criterion2` exists only for 'ce+boundary'.
    if loss_type == 'ce':
        criterion = CrossEntropyLoss2d().to(device)
    elif loss_type == 'union':
        criterion = UnionLossWithCrossEntropyAndDiceLoss().to(device)
    elif loss_type == 'ce+size':
        criterion = UnionLossWithCrossEntropyAndSize().to(device)
    elif loss_type == 'ce+boundary':
        class_weight = torch.autograd.Variable(torch.FloatTensor([1, 10]))
        criterion = UnionLossWithCrossEntropyAndSize().to(device)
        criterion2 = UnionLossWithCrossEntropyAndDiceAndBoundary().to(device)
    else:
        print('Do not have this loss')

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)
    # optimizer = SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # NOTE(review): `scheduler` is only bound when lr_policy == 'StepLR', yet
    # scheduler.step() is called unconditionally in the epoch loop below.
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    # training process
    logging.info('Start Training For Polyp Seg')
    from skimage.segmentation import slic, mark_boundaries
    import numpy as np
    size_loss_epoch = args.size_loss_epoch

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        scheduler.step()

        # ---- train phase ----
        net.train()
        train_loss = 0.
        for batch_idx, (inputs, gts, mask_names) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size) + 1):
            # Compute SLIC superpixels per image; seg_labs[i] holds, for each
            # superpixel of image i, the flat pixel indices it covers.
            seg_labs = []
            images = []
            # NOTE(review): `input` shadows the builtin of the same name.
            for input in inputs:
                # De-normalize back to uint8 HWC for the superpixel routine
                # (assumes inputs were normalized with mean=std=0.5 -- TODO confirm).
                input_tem = input.clone().numpy().astype(np.double)
                input_tem = (((input_tem * 0.5) + 0.5) * 255).astype(np.uint8)
                input_tem = np.transpose(input_tem, (1, 2, 0))
                images.append(input_tem)
                # seg_map = segmentation.felzenszwalb(input_tem, scale=32, sigma=0.5, min_size=128)
                # image_path = os.path.join('../data/CVC-912/train/images', img_names)
                seg_map = slic(input_tem, n_segments=100, compactness=10)
                # out = mark_boundaries(input_tem, seg_map)
                seg_map = seg_map.flatten()
                seg_lab = [
                    np.where(seg_map == u_label)[0]
                    for u_label in np.unique(seg_map)
                ]
                seg_labs.append(seg_lab)

            inputs = inputs.to(device)
            gts = gts.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            # if len(torch.where(torch.isnan(outputs))[0]) != 0:
            #     print('app nan:' + str(batch_idx))
            # outputs holds the raw network predictions

            # Build per-image training targets (superpixel-smoothed).
            new_gts_tem = []
            for i in range(len(inputs)):
                output = outputs[i]
                output = output.permute(1, 2, 0).view(-1, 2)  # 2 classes in total
                # output = output.permute(1, 2, 0).view(-1, args.mod_dim2)
                # Keep predicted foreground only where the ground truth is
                # foreground, then dump the prediction image for inspection.
                target = torch.argmax(output, 1)
                target = target * (gts[i].view(-1))
                im_target = target.data.cpu().numpy()
                # NOTE(review): 288x384 image size is hard-coded here and below.
                net_output = np.resize(im_target, (288, 384))
                if not os.path.exists(os.path.join('output', str(epoch))):
                    os.mkdir(os.path.join('output', str(epoch)))
                cv2.imwrite(os.path.join('output', str(epoch), mask_names[i]), net_output * 255)

                if epoch < args.size_loss_epoch:
                    # Early epochs: smooth the ground truth with superpixels --
                    # each superpixel takes its majority ground-truth label.
                    spixel_gt = gts[i].view(-1, 1).cpu().detach().numpy()
                    for inds in seg_labs[i]:
                        u_labels, hist = np.unique(spixel_gt[inds], return_counts=True)
                        spixel_gt[inds] = u_labels[np.argmax(hist, 0)]
                    # Fall back to the raw ground truth if smoothing erased
                    # all foreground pixels.
                    if spixel_gt.sum() == 0:
                        spixel_gt = gts[i].view(-1, 1).cpu().detach().numpy()
                    target = torch.from_numpy(spixel_gt).long()
                    target = torch.reshape(target, (288, 384))
                    if target.sum().numpy() == 0:
                        target = gts[i].cpu()
                    target = target.unsqueeze(0)
                    new_gts_tem.append(target)
                else:
                    '''refine'''
                    # Later epochs: majority-vote the network's own prediction
                    # inside each superpixel and train against that.
                    for inds in seg_labs[i]:
                        u_labels, hist = np.unique(im_target[inds], return_counts=True)
                        im_target[inds] = u_labels[np.argmax(hist, 0)]
                    # Fall back to ground truth when the refined prediction is empty.
                    if im_target.sum() == 0:
                        im_target = gts[i].view(-1, 1).cpu().detach().numpy()
                    im_target = np.resize(im_target, (288, 384))
                    spixel_gt = im_target
                    target = torch.from_numpy(im_target).long()
                    # gt = gts[i].cpu().clone().detach()
                    # cv2.imwrite(os.path.join('123', img_names[i]), target.numpy() * 255)
                    # target = torch.reshape(target,(288,384))
                    target = target.unsqueeze(0)
                    new_gts_tem.append(target)

                # Dump the generated pseudo ground truth for inspection.
                if not os.path.exists(os.path.join('new_gt', str(epoch))):
                    os.mkdir(os.path.join('new_gt', str(epoch)))
                cv2.imwrite(os.path.join('new_gt', str(epoch), mask_names[i]),
                            spixel_gt.reshape((288, 384)) * 255)

            # Stack per-image targets back into a batch.
            new_gt = torch.cat([x for x in new_gts_tem], 0)
            #new_gt = torch.cat((temps[0],temps[1]),0)
            new_gt = new_gt.to(device)
            # from utils.crf import dense_crf
            # crf_res = []
            # for i in range(len(images)):
            #     out = dense_crf(images[i], output_softmax[i][1].cpu().detach().numpy())
            #     crf_res.append(torch.from_numpy(out).unsqueeze(0))
            # crf_results = torch.cat([x for x in crf_res]).to(device)

            # Loss switches at size_loss_epoch; criterion2's extra weights
            # ramp linearly with the epoch index.
            if epoch < size_loss_epoch:
                loss = criterion(outputs, new_gt)
            else:
                loss = criterion2(outputs, new_gt, 0.01 * epoch, 1 - 0.01 * epoch)
            # if args.loss == 'ce-dice':
            #     # loss = 2 * criterion1(outputs, targets) + criterion2(outputs, targets)
            #     pass
            # elif loss_type == 'ce+boundary':
            #     loss1 = criterion(outputs, new_gt)
            #     loss2 = criterion2(outputs, new_gt)
            #     loss = loss1 + 0.0001 * loss2
            # else:
            #     loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Average the accumulated loss over the number of batches seen.
        train_loss_epoch = train_loss / (batch_idx + 1)
        history['train_loss'].append(train_loss_epoch)
        writer.add_scalar('training_loss', train_loss_epoch, epoch)

        # ---- test phase ----
        net.eval()
        test_loss = 0.
        test_dice = 0.
        for batch_idx, (inputs, targets, images) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size) + 1):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                # NOTE(review): for 'ce-dice' no loss is computed, so the
                # value from the previous batch is re-accumulated below.
                if args.loss == 'ce-dice':
                    # loss = 2 * criterion1(outputs, targets) + criterion2(outputs, targets)
                    pass
                else:
                    loss = criterion(outputs, targets)
                dice = dice_fn(outputs, targets)
                test_loss += loss.item()
                test_dice += dice.item()
        test_loss_epoch = test_loss / (batch_idx + 1)
        test_dice_epoch = test_dice / len(test_loader.dataset)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)
        writer.add_scalar('validation_loss', test_loss_epoch, epoch)
        writer.add_scalar('validation_dice', test_dice_epoch, epoch)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               test_dice_epoch, time_cost))

        # save checkpoint whenever the test loss improves
        if test_loss_epoch < best_loss:
            # if test_dice_epoch > best_dice:
            logging.info('Checkpoint Saving...')
            save_model = net
            # if torch.cuda.device_count() > 1:
            #     save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            torch.save(state, checkpoint_name)
            best_loss = test_loss_epoch
            # best_dice = test_dice_epoch
        # if test_dice_epoch > best_dice:
        #     logging.info('Checkpoint Saving...')
        #     # save_model = net
        #     # if torch.cuda.device_count() > 1:
        #     #     save_model = list(net.children())[0]
        #     state = {
        #         'net': save_model.state_dict(),
        #         'loss': test_loss_epoch,
        #         'dice': test_dice_epoch,
        #         'epoch': epoch + 1,
        #         'history': history
        #     }
        #     checkpoint_name_dice = os.path.join(checkpoint_path, 'dice_' + args.fold_num + args.params_name)
        #     torch.save(state, checkpoint_name_dice)
        #     best_dice = test_dice_epoch

    writer.close()
    return net
def main(train_args):
    """Train FCN8s on Pascal VOC, optionally resuming from a snapshot.

    train_args keys used: 'snapshot', 'lr', 'weight_decay', 'momentum',
    'lr_patience', 'epoch_num'.  Mutates train_args by storing a
    'best_record' dict with the best validation metrics seen so far.
    """
    net = FCN8s(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with a sentinel best record.
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot filenames encode metrics as 'epoch_<e>_loss_<l>_acc_<a>_...';
        # the odd indices below pick the values back out of that pattern.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    # ImageNet per-channel normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Undo normalization so tensors can be rendered as PIL images again.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train', transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True)
    val_set = voc.VOC('val', transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=voc.ignore_label).cuda()

    # Bias parameters get twice the base LR and no weight decay; note the
    # first Adam beta is driven by the 'momentum' hyper-parameter.
    optimizer = optim.Adam([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias'
        ],
        'lr': 2 * train_args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias'
        ],
        'lr': train_args['lr'],
        'weight_decay': train_args['weight_decay']
    }], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        # Re-apply the per-group learning rates: the saved optimizer state
        # may carry stale values from the previous run.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Bug fix: open(...).write(...) leaked the file handle; the context
    # manager guarantees the run-config log is flushed and closed.
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(train_args) + '\n\n')

    # Reduce the LR when validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer, 'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)
# --- script-level training setup: snapshot, summary, optimizer, data ---
# Save the initial weights so the run can be reproduced from its start state.
torch.save(model.state_dict(),
           '{}/models/{}_{}.pkl'.format(LOGDIR, args.model, args.start_epoch))
model.train()
nparams = get_nparams(model)

try:
    # Best-effort model summary; torchsummary is an optional dependency.
    from torchsummary import summary
    summary(model, input_size=(1, 640, 400))
    print("Max params:", 1024 * 1024 / 4.0)
    # NOTE(review): str(model.parameters) stringifies the bound method, not
    # the parameters -- was str(model) or model.parameters() intended?
    logger.write_summary(str(model.parameters))
except Exception:
    # Bug fix: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt.  The summary is best-effort, so ordinary failures
    # (missing torchsummary, summary errors) are still tolerated.
    print("Torch summary not found !!!")

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Halve-on-plateau style schedule driven by validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

criterion = CrossEntropyLoss2d()
criterion_DICE = GeneralizedDiceLoss(softmax=True, reduction=True)
criterion_SL = SurfaceLoss()

Path2file = args.dataset
train = IrisDataset(filepath=Path2file, split='train',
                    transform=transform, **kwargs)
valid = IrisDataset(filepath=Path2file, split='validation',
                    transform=transform, **kwargs)

trainloader = DataLoader(train, batch_size=args.bs,
                         shuffle=True, num_workers=args.workers)
validloader = DataLoader(valid, batch_size=args.bs,
                         shuffle=False, num_workers=args.workers)
def __init__(self, is_testing, load_snapshot, snapshot_file, force_cpu):
    """Set up the reactive pushing/grasping trainer.

    Args:
        is_testing: accepted for interface compatibility; not used here.
        load_snapshot: when True, restore model weights from snapshot_file.
        snapshot_file: path to a saved state_dict.
        force_cpu: run on CPU even when CUDA is available.
    """
    # Decide on the compute device.
    # Bug fix: the original printed "CUDA detected, but overriding ..." for
    # every force_cpu run, even on hosts where CUDA is not available at all;
    # the availability check now gates that message.
    if torch.cuda.is_available() and not force_cpu:
        print("CUDA detected. Running with GPU acceleration.")
        self.use_cuda = True
    elif torch.cuda.is_available() and force_cpu:
        print(
            "CUDA detected, but overriding with option '--cpu'. Running with only CPU."
        )
        self.use_cuda = False
    else:
        print("CUDA is *NOT* detected. Running with only CPU.")
        self.use_cuda = False

    # Fully convolutional classification network for supervised learning.
    # (The candidate methods differ in loss: classification loss here vs.
    # Huber loss for the RL variant -- see note near the bottom.)
    self.model = reactive_net(self.use_cuda)

    # Classification loss for pushing; the last class carries zero weight so
    # "no loss" pixels do not contribute to the gradient.
    push_num_classes = 3  # 0 - push, 1 - no change push, 2 - no loss push
    push_class_weights = torch.ones(push_num_classes)
    push_class_weights[push_num_classes - 1] = 0
    if self.use_cuda:
        self.push_criterion = CrossEntropyLoss2d(
            push_class_weights.cuda()).cuda()
    else:
        self.push_criterion = CrossEntropyLoss2d(push_class_weights)

    # Classification loss for grasping, same zero-weight scheme.
    grasp_num_classes = 3  # 0 - grasp, 1 - failed grasp, 2 - no loss
    grasp_class_weights = torch.ones(grasp_num_classes)
    grasp_class_weights[grasp_num_classes - 1] = 0
    if self.use_cuda:
        self.grasp_criterion = CrossEntropyLoss2d(
            grasp_class_weights.cuda()).cuda()
    else:
        self.grasp_criterion = CrossEntropyLoss2d(grasp_class_weights)

    # NOTE: a deep-RL variant (reinforcement_net + SmoothL1/Huber loss with
    # push_rewards and future_reward_discount) previously lived here as
    # commented-out code; recover it from version control if the
    # 'reinforcement' method is revived.

    # Load pre-trained model snapshot.
    if load_snapshot:
        self.model.load_state_dict(torch.load(snapshot_file))
        print('Pre-trained model snapshot loaded from: %s' % (snapshot_file))

    # Convert model from CPU to GPU.
    if self.use_cuda:
        self.model = self.model.cuda()

    # Set model to training mode.
    self.model.train()

    # Initialize optimizer.
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-4,
                                     momentum=0.9, weight_decay=2e-5)
    self.iteration = 0

    # Initialize lists to save execution info and RL variables.
    self.executed_action_log = []
    self.label_value_log = []
    self.reward_value_log = []
    self.predicted_value_log = []
    self.use_heuristic_log = []
    self.is_exploit_log = []
    self.clearance_log = []
def main():
    """Train PSPNet on Cityscapes (fine annotations), optionally resuming.

    Reads configuration from the module-level `args` dict and mutates it by
    storing a 'best_record' dict with the best validation metrics so far.
    """
    net = PSPNet(num_classes=cityscapes.num_classes)

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        # Fresh run: start at epoch 1 with a sentinel best record.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot filenames encode metrics as
        # 'epoch_<e>_iter_<i>_loss_<l>_...'; the odd indices below pick the
        # values back out of that pattern.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'acc': float(split_snapshot[7]),
            'acc_cls': float(split_snapshot[9]),
            'mean_iu': float(split_snapshot[11]),
            'fwavacc': float(split_snapshot[13])
        }
    net.cuda().train()

    # ImageNet per-channel normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    # Crop large Cityscapes frames into overlapping training windows.
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                cityscapes.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = cityscapes.CityScapes('fine', 'train',
                                      joint_transform=train_joint_transform,
                                      sliding_crop=sliding_crop,
                                      transform=train_input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val',
                                    transform=val_input_transform,
                                    sliding_crop=sliding_crop,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(
        size_average=True, ignore_index=cityscapes.ignore_label).cuda()

    # Bias parameters get twice the base LR and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias'
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias'
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the per-group learning rates: the saved optimizer state
        # may carry stale values from the previous run.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Bug fix: open(...).write(...) leaked the file handle; the context
    # manager guarantees the run-config log is flushed and closed.
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
def __init__(self, opt):
    """Build the CycleMcd model: CycleGAN generators/discriminators plus an
    MCD segmentation branch (shared features + two pixel classifiers), their
    losses, metric trackers, image pools and optimizers.

    Args:
        opt: options namespace; fields read here include lsgan, device,
            nClass, inputCh, ngf, which_model_netG/netD, n_layers_D, norm,
            dropout, init_type, init_gain, gpuIds, segNet, pool_size,
            cycleOpt, mcdOpt, lr, beta1, momentum, weight_decay.
    """
    super(CycleMcdModel, self).__init__(opt)
    print('-------------- Networks initializing -------------')
    self.mode = None
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.lossNames = [
        'loss{}'.format(i) for i in [
            'GenA', 'DisA', 'CycleA', 'IdtA', 'DisB', 'GenB', 'CycleB',
            'IdtB', 'Supervised', 'UnsupervisedClassifier',
            'UnsupervisedFeature'
        ]
    ]
    # Running loss values, updated during optimization; start at zero.
    self.lossGenA, self.lossDisA, self.lossCycleA, self.lossIdtA = 0, 0, 0, 0
    self.lossGenB, self.lossDisB, self.lossCycleB, self.lossIdtB = 0, 0, 0, 0
    self.lossSupervised, self.lossUnsupervisedClassifier, self.lossUnsupervisedFeature = 0, 0, 0

    # define loss functions
    self.criterionGAN = networks.GANLoss(use_lsgan=opt.lsgan).to(
        opt.device)  # lsgan = True use MSE loss, False use BCE loss
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()
    self.criterionSeg = CrossEntropyLoss2d(opt)  # 2d for each pixels
    self.criterionDis = Distance(opt)

    # specify the training miou you want to print out. The program will call base_model.get_current_mious
    self.miouNames = [
        'miou{}'.format(i) for i in
        ['SupervisedA', 'UnsupervisedA', 'SupervisedB', 'UnsupervisedB']
    ]
    # One IoU accumulator per domain/supervision combination.
    self.miouSupervisedA = IouEval(opt.nClass)
    self.miouUnsupervisedA = IouEval(opt.nClass)
    self.miouSupervisedB = IouEval(opt.nClass)
    self.miouUnsupervisedB = IouEval(opt.nClass)

    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    # only image doesn't have prefix
    imageNamesA = [
        'realA', 'fakeA', 'recA', 'idtA', 'supervisedA', 'predSupervisedA',
        'gndSupervisedA', 'unsupervisedA', 'predUnsupervisedA',
        'gndUnsupervisedA'
    ]
    imageNamesB = [
        'realB', 'fakeB', 'recB', 'idtB', 'supervisedB', 'predSupervisedB',
        'gndSupervisedB', 'unsupervisedB', 'predUnsupervisedB',
        'gndUnsupervisedB'
    ]
    self.imageNames = imageNamesA + imageNamesB
    # Image/prediction slots, filled in during forward passes.
    self.realA, self.fakeA, self.recA, self.idtA = None, None, None, None
    self.supervisedA, self.predSupervisedA, self.gndSupervisedA = None, None, None
    self.unsupervisedA, self.predUnsupervisedA, self.gndUnsupervisedA = None, None, None
    self.realB, self.fakeB, self.recB, self.idtB = None, None, None, None
    self.supervisedB, self.predSupervisedB, self.gndSupervisedB = None, None, None
    self.unsupervisedB, self.predUnsupervisedB, self.gndUnsupervisedB = None, None, None

    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    # naming is by the input domain
    # Cycle gan model: 'GenA', 'DisA', 'GenB', 'DisB'
    # Mcd model : 'Features', 'Classifier1', 'Classifier2'
    self.modelNames = [
        'net{}'.format(i) for i in [
            'GenA', 'DisA', 'GenB', 'DisB', 'Features', 'Classifier1',
            'Classifier2'
        ]
    ]

    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_RGB (G), G_D (F), D_RGB (D_Y), D_D (D_X)
    self.netGenA = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                     opt.which_model_netG, opt.norm,
                                     opt.dropout, opt.init_type,
                                     opt.init_gain, opt.gpuIds)
    self.netDisA = networks.define_D(opt.inputCh, opt.inputCh,
                                     opt.which_model_netD, opt.n_layers_D,
                                     opt.norm, not opt.lsgan, opt.init_type,
                                     opt.init_gain, opt.gpuIds)
    self.netGenB = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                     opt.which_model_netG, opt.norm,
                                     opt.dropout, opt.init_type,
                                     opt.init_gain, opt.gpuIds)
    self.netDisB = networks.define_D(opt.inputCh, opt.inputCh,
                                     opt.which_model_netD, opt.n_layers_D,
                                     opt.norm, not opt.lsgan, opt.init_type,
                                     opt.init_gain, opt.gpuIds)
    # MCD branch: shared feature extractor and two pixel classifiers.
    self.netFeatures = self.initNet(
        DRNSegBase(model_name=opt.segNet, n_class=opt.nClass,
                   input_ch=opt.inputCh))
    self.netClassifier1 = self.initNet(
        DRNSegPixelClassifier(n_class=opt.nClass))
    self.netClassifier2 = self.initNet(
        DRNSegPixelClassifier(n_class=opt.nClass))
    self.set_requires_grad([
        self.netGenA, self.netGenB, self.netDisA, self.netDisB,
        self.netFeatures, self.netClassifier1, self.netClassifier2
    ], True)

    # define image pool: buffers of previously generated fakes for the
    # discriminators (standard CycleGAN trick).
    self.fakeAPool = ImagePool(opt.pool_size)
    self.fakeBPool = ImagePool(opt.pool_size)

    # initialize optimizers: G/D use the cycle-GAN settings, F/C the MCD ones.
    self.optimizerG = getOptimizer(itertools.chain(
        self.netGenA.parameters(), self.netGenB.parameters()),
                                   opt=opt.cycleOpt,
                                   lr=opt.lr,
                                   beta1=opt.beta1,
                                   momentum=opt.momentum,
                                   weight_decay=opt.weight_decay)
    self.optimizerD = getOptimizer(itertools.chain(
        self.netDisA.parameters(), self.netDisB.parameters()),
                                   opt=opt.cycleOpt,
                                   lr=opt.lr,
                                   beta1=opt.beta1,
                                   momentum=opt.momentum,
                                   weight_decay=opt.weight_decay)
    self.optimizerF = getOptimizer(itertools.chain(
        self.netFeatures.parameters()),
                                   opt=opt.mcdOpt,
                                   lr=opt.lr,
                                   beta1=opt.beta1,
                                   momentum=opt.momentum,
                                   weight_decay=opt.weight_decay)
    self.optimizerC = getOptimizer(itertools.chain(
        self.netClassifier1.parameters(),
        self.netClassifier2.parameters()),
                                   opt=opt.mcdOpt,
                                   lr=opt.lr,
                                   beta1=opt.beta1,
                                   momentum=opt.momentum,
                                   weight_decay=opt.weight_decay)
    self.optimizers = []
    self.optimizers.append(self.optimizerG)
    self.optimizers.append(self.optimizerD)
    self.optimizers.append(self.optimizerF)
    self.optimizers.append(self.optimizerC)

    # Label-to-color mapping for visualizing predictions.
    self.colorize = Colorize()
    print('--------------------------------------------------')
def Train(train_root, train_csv, test_csv):
    """Train a binary breast-segmentation network and checkpoint on the best
    training Dice.

    After each epoch the model is evaluated (no grad) on both the test split
    and the full training split; the checkpoint is saved whenever the
    training Dice improves.

    Args:
        train_root: root directory containing the images; also used for the
            test dataset below.
        train_csv / test_csv: CSV files listing the split samples.
    """
    # parameters
    args = parse_args()
    # record
    record_params(args)

    # Seed everything for reproducibility; cudnn.deterministic is only set
    # on the benchmark branch (args.cudnn != 0).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Binary segmentation: background vs. breast tissue.
    num_classes = 2
    net = build_model(args.model_name, num_classes)
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data: identical resize/normalize pipelines for train and test.
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    # NOTE(review): both datasets use root=train_root; confirm the test CSV
    # really refers to files under the training root.
    train_dataset = breast_seg(root=train_root,
                               csv_file=train_csv,
                               transform=train_aug)
    test_dataset = breast_seg(root=train_root,
                              csv_file=test_csv,
                              transform=test_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    # NOTE(review): for an unrecognized args.loss, `criterion` is never bound
    # and the first use below raises NameError.
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    # NOTE(review): `scheduler` is only bound for 'StepLR'/'PolyLR', yet
    # scheduler.step() below runs for every policy other than 'None'.
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Breast Seg')
    besttraindice = 0.

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # ---- optimization pass over the training split ----
        net.train()
        for batch_idx, (imgs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            imgs = imgs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(imgs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # ---- evaluation on the test split (no grad) ----
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0
        for batch_idx, (imgs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                imgs = imgs.to(device)
                targets = targets.to(device)
                outputs = net(imgs)
                loss = criterion(outputs, targets).mean()
                # Sample-weighted accumulation so partial batches average correctly.
                test_count += imgs.shape[0]
                test_loss += loss.item() * imgs.shape[0]
                test_dice += Dice_fn(outputs, targets).item()
        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # ---- re-evaluation on the training split (no grad) ----
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (imgs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                imgs = imgs.to(device)
                targets = targets.to(device)
                outputs = net(imgs)
                loss = criterion(outputs, targets).mean()
                train_count += imgs.shape[0]
                train_loss += loss.item() * imgs.shape[0]
                train_dice += Dice_fn(outputs, targets).item()
        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        # save checkpoint when the training Dice improves
        if train_dice_epoch > besttraindice:
            besttraindice = train_dice_epoch
            logging.info('Besttraindice Checkpoint {} Saving...'.format(epoch + 1))
            # Unwrap DataParallel so the saved state_dict has plain key names.
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
def Train(train_root, train_csv, test_csv, tempmaskfolder):
    """Co-train two segmentation networks on kidney slices with noisy labels.

    Cross-supervision scheme: net1 is trained against annotator-2 masks
    (``targets2``) and net2 against annotator-1 masks (``targets1``).  Each
    net is additionally regularized toward the *peer* network's
    augmentation-averaged, temperature-sharpened pseudo labels.  Every epoch
    the whole training set is re-scored; during warm-up (and every 10th
    epoch after) the worst-scoring samples get their masks overwritten with
    generated pseudo masks under ``train_root/tempmaskfolder``.

    Args:
        train_root: dataset root; pseudo masks are written below it.
        train_csv: csv listing the training slices.
        test_csv: csv listing the test slices.
        tempmaskfolder: sub-folder (under train_root) for regenerated masks.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))
    # Checkpoint bookkeeping: saving only begins once the evaluated train
    # dice passes its first turning point ("ascending" phase).
    besttraindice = 0.0
    changepointdice = 0.0
    ascending = False
    # parameters
    args = parse_args()
    record_params(args)
    # Reproducibility: seed every RNG in play (torch, cuda, numpy, random).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2  # binary segmentation (background / kidney)
    net1 = build_model(args.model1_name, num_classes)
    net2 = build_model(args.model2_name, num_classes)
    # resume
    # Both nets are initialized from the SAME warmed-up checkpoint
    # (args.resumefile); the paths below are where new checkpoints go.
    params1_name = '{}_warmup{}_temp{}_r{}_net1.pkl'.format(
        args.model1_name, args.warmup_epoch, args.temperature,
        args.repetition)
    params2_name = '{}_warmup{}_temp{}_r{}_net2.pkl'.format(
        args.model2_name, args.warmup_epoch, args.temperature,
        args.repetition)
    checkpoint1_path = os.path.join(args.checkpoint, params1_name)
    checkpoint2_path = os.path.join(args.checkpoint, params2_name)
    initializecheckpoint = torch.load(args.resumefile)['net']
    net1.load_state_dict(initializecheckpoint)
    net2.load_state_dict(initializecheckpoint)
    start_epoch = 0
    end_epoch = args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)
    # data
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               tempmaskfolder=tempmaskfolder,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(
        root=train_root,
        csv_file=test_csv,
        tempmaskfolder=tempmaskfolder,
        maskidentity=args.maskidentity,
        train=False,
        transform=test_aug)
    # tempmaskfolder=tempmaskfolder,
    # drop_last=True keeps batches full so the indx[0:2]/indx[2:] splits
    # below always see a fixed batch size.
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)
    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        # NOTE(review): criterion stays undefined here; a later reference
        # would raise NameError.
        print('Do not have this loss')
    # Per-pixel MSE (reduction='none') so it can be weighted by the
    # confidence weightmaps below.
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)
    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)
    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)
    # training process
    logging.info('Start Training For Kidney Seg')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        # Quadratic ramp of the correction-loss weight over the warm-up
        # epochs, capped at 1.0.
        if args.warmup_epoch == 0:
            rate_schedule[epoch] = 1.0
        else:
            rate_schedule[epoch] = min(
                (float(epoch) / float(args.warmup_epoch))**2, 1.0)
        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader), total=int(
                    len(train_loader.dataset) / args.batch_size)):
            # (inputs, augset, targets, targets1, targets2)
            # Pseudo-label generation: run both nets in eval mode over each
            # augmented view, then undo the augmentations on the outputs.
            net1.eval()
            net2.eval()
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimg = augset['img{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimg).detach())
                augoutput2.append(net2(augimg).detach())
            #
            # NOTE(review): the original source shows a stray '#' before the
            # next line; both reversals are kept live here to mirror the
            # prostate variant — confirm against the repository.
            augoutput1 = reverseaugbatch(augset, augoutput1,
                                         classno=num_classes)
            augoutput2 = reverseaugbatch(augset, augoutput2,
                                         classno=num_classes)
            # Average the softmaxed, de-augmented predictions, then sharpen.
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)
                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # Confidence weightmap: 1 - 4*p0*p1 is 0 at p=(0.5, 0.5) and 1
            # where the pseudo label is fully confident (two-class case).
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:, 0, :, :] * \
                pseudo_label1[:, 1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:, 0, :, :] * \
                pseudo_label2[:, 1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)
            # Supervised forward pass (back to train mode).
            net1.train()
            net2.train()
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)
            # Cross supervision: net1 learns annotator-2 masks, net2 learns
            # annotator-1 masks.  Per-image losses are sorted so the
            # 2 lowest-loss samples (by the PEER's ranking) get full
            # supervision; the rest are down-weighted and pushed toward the
            # peer's pseudo labels instead.
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            # Correction loss: confidence-weighted MSE toward the peer's
            # pseudo labels, ramped in by rate_schedule.
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor
            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            # retain_graph=True: loss1 and loss2 share parts of the graph
            # (both read outputs1/outputs2), so the first backward must keep
            # the graph alive for the second.
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inputs.shape[0]
            train_loss1 += loss1.item() * inputs.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inputs.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)
        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader),
                     total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                targets1 = targets1.to(device)
                targets2 = targets2.to(device)
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
                test_count += inputs.shape[0]
                test_loss1 += loss1.item() * inputs.shape[0]
                test_dice1 += Dice_fn(outputs1, targets2).item()
                test_loss2 += loss2.item() * inputs.shape[0]
                test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)
        # Re-score every training sample with the current nets; the per-
        # sample dices drive the pseudo-mask update below.
        traindices1 = torch.zeros(len(train_dataset))
        traindices2 = torch.zeros(len(train_dataset))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_dataset)),
                              total=len(train_dataset)):
            sample = train_dataset.__getitem__(casecount)
            img = sample[0]
            # NOTE(review): mask1 comes from sample[4] and mask2 from
            # sample[3] (swapped relative to the prostate variant) — confirm
            # against the kidney_seg dataset's item layout.
            mask1 = sample[4]
            mask2 = sample[3]
            with torch.no_grad():
                img = torch.unsqueeze(img.to(device), 0)
                output1 = net1(img)
                output1 = F.softmax(output1, dim=1)
                output2 = net2(img)
                output2 = F.softmax(output2, dim=1)
                output1 = torch.argmax(output1, dim=1)
                output2 = torch.argmax(output2, dim=1)
                output1 = output1.squeeze().cpu()
                generatedoutput1 = output1.unsqueeze(dim=0).numpy()
                output2 = output2.squeeze().cpu()
                generatedoutput2 = output2.unsqueeze(dim=0).numpy()
                traindices1[casecount] = Dice2d(generatedoutput1,
                                                mask1.numpy())
                traindices2[casecount] = Dice2d(generatedoutput2,
                                                mask2.numpy())
                generatedmask1.append(generatedoutput1)
                generatedmask2.append(generatedoutput2)
        evaltrainavgdice1 = traindices1.sum() / float(len(train_dataset))
        evaltrainavgdice2 = traindices2.sum() / float(len(train_dataset))
        evaltrainavgdicetemp = (evaltrainavgdice1 + evaltrainavgdice2) / 2.0
        # Map maskidentity (1/2/3) to the dataset's annotation file lists.
        maskannotations = {
            '1': train_dataset.mask1,
            '2': train_dataset.mask2,
            '3': train_dataset.mask3
        }
        # update pseudolabel
        # During warm-up every epoch, afterwards every 10th epoch: replace
        # the masks of the worst-scoring update_percent of samples with the
        # nets' own predictions (saved as *_net1 / *_net2 .nii.gz files).
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            avgdice = evaltrainavgdicetemp  # kept for parity; unused below
            selected_samples = int(args.update_percent * len(train_dataset))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traindices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net1.nii.gz')
                save_data = generatedmask1[selectedidx]
                # Skip all-background predictions.
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modified for net1'.format(
                len(selectedidxs)))
            _, sortidx2 = traindices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net2.nii.gz')
                save_data = generatedmask2[selectedidx]
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modify for net2'.format(len(selectedidxs)))
        # Turning-point logic: once train dice rises for the first time
        # after epoch 0, enter the "ascending" phase; only then do best-
        # dice checkpoints get saved.
        if epoch > 0 and changepointdice < evaltrainavgdicetemp and ascending == False:
            ascending = True
            besttraindice = changepointdice
        if evaltrainavgdicetemp > besttraindice and ascending:
            besttraindice = evaltrainavgdicetemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net1
            if torch.cuda.device_count() > 1:
                # Unwrap the DataParallel container before saving.
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state,
                '{}_besttraindice.pkl'.format(
                    checkpoint1_path.split('.pkl')[0]))
            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state,
                '{}_besttraindice.pkl'.format(
                    checkpoint2_path.split('.pkl')[0]))
        if not ascending:
            # Still descending: track the latest dice as the turning-point
            # candidate.
            changepointdice = evaltrainavgdicetemp
        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | '
            'train_dice1: %.3f | test_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | '
            'train_dice2: %.3f | test_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: evaltrain_dice1: %.3f | evaltrain_dice2: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, evaltrainavgdice1, evaltrainavgdice2,
               time_cost))
        net1.train()
        net2.train()
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_csv, traincase_csv, testcase_csv,
          labelcase_csv, tempmaskfolder):
    """Co-train two identical networks on prostate slices with noisy labels.

    Same cross-supervision + pseudo-label-correction scheme as the kidney
    variant, but evaluation is done per 3D case: slice predictions are
    stacked, post-processed with largest-connected-component filtering, and
    scored with a 3D dice.  The worst-scoring unlabeled cases periodically
    get their volume masks replaced by generated pseudo masks, and the
    pseudo-mask folder is snapshotted whenever the train-case dice improves.

    Args:
        train_root: dataset root; pseudo masks go below it.
        train_csv / test_csv: csvs listing individual train / test slices.
        traincase_csv / testcase_csv: csvs listing whole 3D cases
            ('Image' and 'Mask' columns).
        labelcase_csv: cases with trusted labels — excluded from pseudo-mask
            replacement.
        tempmaskfolder: sub-folder (under train_root) for regenerated masks.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))
    # parameters
    args = parse_args()
    # record
    record_params(args)
    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    train_masks = pd.read_csv(traincase_csv)['Mask'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()
    label_cases = pd.read_csv(labelcase_csv)['Image'].tolist()
    # Reproducibility: seed every RNG in play.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2  # binary segmentation (background / prostate)
    # Unlike the kidney variant, both nets share one architecture name.
    net1 = build_model(args.model_name, num_classes)
    net2 = build_model(args.model_name, num_classes)
    params1_name = '{}_temp{}_r{}_net1.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    params2_name = '{}_temp{}_r{}_net2.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    start_epoch = 0
    end_epoch = args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)
    # data
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 tempmaskfolder=tempmaskfolder,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=train_root,
                                csv_file=test_csv,
                                tempmaskfolder=tempmaskfolder,
                                transform=test_aug)
    # drop_last=True keeps batches full so the indx[0:2]/indx[2:] splits
    # below always see a fixed batch size.
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)
    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        # NOTE(review): criterion stays undefined here; a later reference
        # would raise NameError.
        print('Do not have this loss')
    # Per-pixel MSE (reduction='none') so it can be weighted by the
    # confidence weightmaps below.
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)
    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)
    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)
    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)
    # training process
    logging.info('Start Training For Prostate Seg')
    besttraincasedice = 0.0
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        # Quadratic ramp of the correction-loss weight, capped at 1.0.
        # NOTE(review): no warmup_epoch == 0 guard here (unlike the kidney
        # variant) — warmup_epoch = 0 would divide by zero.
        rate_schedule[epoch] = min(
            (float(epoch) / float(args.warmup_epoch))**2, 1.0)
        # train
        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader),
                     total=int(len(train_loader.dataset) / args.batch_size)):
            # Pseudo-label generation: run both nets over each augmented
            # view (detached), then undo the augmentations on the outputs.
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimg = augset['img{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimg).detach())
                augoutput2.append(net2(augimg).detach())
            augoutput1 = reverseaug(augset, augoutput1, classno=num_classes)
            augoutput2 = reverseaug(augset, augoutput2, classno=num_classes)
            # Average the softmaxed, de-augmented predictions, then sharpen.
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)
                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # Confidence weightmap: 1 - 4*p0*p1 is 0 at p=(0.5, 0.5) and 1
            # where the pseudo label is fully confident (two-class case).
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:, 0, :, :] * \
                pseudo_label1[:, 1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:, 0, :, :] * \
                pseudo_label2[:, 1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)
            inputs = inputs.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)
            # Cross supervision with peer-ranked sample selection: the 2
            # samples ranked lowest-loss by the PEER get full supervision;
            # the rest are down-weighted and pulled toward the peer's
            # pseudo labels via the confidence-weighted correction loss.
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor
            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            # retain_graph=True: loss1 and loss2 share graph state (both
            # read outputs1/outputs2), so the first backward keeps it alive.
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inputs.shape[0]
            train_loss1 += loss1.item() * inputs.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inputs.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)
        # Debug dump of the loss-component magnitudes from the LAST batch.
        print(rate_schedule[epoch])
        print(args.segcor_weight[0] *
              (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor)
        print(args.segcor_weight[0] *
              (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor)
        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader),
                     total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets1 = targets1.to(device)
                targets2 = targets2.to(device)
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
                test_count += inputs.shape[0]
                test_loss1 += loss1.item() * inputs.shape[0]
                test_dice1 += Dice_fn(outputs1, targets2).item()
                test_loss2 += loss2.item() * inputs.shape[0]
                test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)
        # Per-case 3D evaluation on the TEST cases: stack each case's slice
        # predictions into a volume, keep the largest connected component,
        # and score with a 3D dice.  startimgslices accumulates slice counts
        # so slice index imgidx + startcaseimg addresses the flat dataset.
        testcasedices1 = torch.zeros(len(test_cases))
        testcasedices2 = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount]
            caseimg = [
                file for file in test_dataset.imgs
                if caseidx.split('/')[-1].split('.')[0] in file
            ]
            caseimg.sort()
            casemask = [
                file for file in test_dataset.masks
                if caseidx.split('/')[-1].split('.')[0] in file
            ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                # Sanity check: image and mask file stems must match.
                assert caseimg[imgidx].split('/')[-1].split('.')[0] == \
                    casemask[imgidx].split('/')[-1].split('.')[0].split('_')[0]
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1)
                target2.append(mask2)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output1 = net1(input)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(input)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            testcasedices1[casecount] = Dice3d_fn(generatedtarget1_keeplargest,
                                                  target1)
            testcasedices2[casecount] = Dice3d_fn(generatedtarget2_keeplargest,
                                                  target2)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseimg)
        testcasedice1 = testcasedices1.sum() / float(len(test_cases))
        testcasedice2 = testcasedices2.sum() / float(len(test_cases))
        # Same per-case evaluation on the TRAIN cases; additionally keep the
        # post-processed prediction volumes for the pseudo-mask update.
        traincasedices1 = torch.zeros(len(train_cases))
        traincasedices2 = torch.zeros(len(train_cases))
        # update pseudolabel
        startimgslices = torch.zeros(len(train_cases))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_cases)),
                              total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseimg = [
                file for file in train_dataset.imgs
                if caseidx.split('/')[-1].split('.')[0] in file
            ]
            caseimg.sort()
            casemask = [
                file for file in train_dataset.masks
                if caseidx.split('/')[-1].split('.')[0] in file
            ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                assert caseimg[imgidx].split('/')[-1].split('.')[0] == \
                    casemask[imgidx].split('/')[-1].split('.')[0].split('_')[0]
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1)
                target2.append(mask2)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output1 = net1(input)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(input)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            traincasedices1[casecount] = Dice3d_fn(
                generatedtarget1_keeplargest, target1)
            traincasedices2[casecount] = Dice3d_fn(
                generatedtarget2_keeplargest, target2)
            generatedmask1.append(generatedtarget1_keeplargest)
            generatedmask2.append(generatedtarget2_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseimg)
        traincasedice1 = traincasedices1.sum() / float(len(train_cases))
        traincasedice2 = traincasedices2.sum() / float(len(train_cases))
        traincasediceavgtemp = (traincasedice1 + traincasedice2) / 2.0
        # On improvement: snapshot the whole pseudo-mask folder and save
        # both nets' checkpoints.
        if traincasediceavgtemp > besttraincasedice:
            backfolder = os.path.join(train_root,
                                      tempmaskfolder + '_besttraindice')
            if os.path.exists(backfolder):
                shutil.rmtree(backfolder)
            shutil.copytree(os.path.join(train_root, tempmaskfolder),
                            backfolder)
            besttraincasedice = traincasediceavgtemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net1
            if torch.cuda.device_count() > 1:
                # Unwrap the DataParallel container before saving.
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params1_name.split('.pkl')[0] + '_besttraincasedice.'
                + params1_name.split('.')[-1])
            torch.save(state, savecheckname)
            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params2_name.split('.pkl')[0] + '_besttraincasedice.'
                + params2_name.split('.')[-1])
            torch.save(state, savecheckname)
        # During warm-up every epoch, afterwards every 10th epoch: replace
        # the masks of the worst-scoring 25% of UNLABELED cases with the
        # nets' own post-processed prediction volumes.
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            traincasediceavg = traincasediceavgtemp  # kept; unused below
            selected_samples = int(0.25 *
                                   (len(train_cases) - len(label_cases)))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traincasedices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                # Trusted labeled cases are never overwritten.
                if caseidx not in label_cases:
                    save_name = os.path.join(
                        save_root,
                        '{}_net1.{}'.format(
                            train_masks[selectedidx].split('/')[-1].split('.')
                            [0],
                            train_masks[selectedidx].split('/')[-1].split(
                                '.')[-1]))
                    # Reorder (H, W, slices) -> (slices, H, W) for SimpleITK.
                    smasksave = sitk.GetImageFromArray(
                        np.transpose(generatedmask1[selectedidx], [2, 0, 1]))
                    sitk.WriteImage(smasksave, save_name)
            logging.info('Mask {} modify for net1'.format(
                [train_cases[i].split('/')[-1] for i in selectedidxs]))
            _, sortidx2 = traincasedices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    save_name = os.path.join(
                        save_root,
                        '{}_net2.{}'.format(
                            train_masks[selectedidx].split('/')[-1].split('.')
                            [0],
                            train_masks[selectedidx].split('/')[-1].split(
                                '.')[-1]))
                    smasksave = sitk.GetImageFromArray(
                        np.transpose(generatedmask2[selectedidx], [2, 0, 1]))
                    sitk.WriteImage(smasksave, save_name)
            logging.info('Mask {} modify for net2'.format(
                [train_cases[i].split('/')[-1] for i in selectedidxs]))
        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | train_dice1: %.3f | test_dice1: %.3f || '
            'traincase_dice1: %.3f || testcase_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, traincasedice1,
             testcasedice1, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | train_dice2: %.3f | test_dice2: %.3f || '
            'traincase_dice2: %.3f || testcase_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, traincasedice2,
             testcasedice2, time_cost))
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()