total_volume_path, total_label_path = util.gen_Pancreas_data_path(arg.data_path)
train_volume_path, train_label_path, valid_volume_path, valid_label_path \
    = util.divide_data2train_valid(total_volume_path, total_label_path, 62, 1997)
valid_volume_path = [i.replace('region', 'region2') for i in valid_volume_path]
valid_label_path = [i.replace('region', 'region2') for i in valid_label_path]

model_multi_views = []
for i in range(3):
    model_view = UNet3D(1, 2, has_dropout=True).to("cuda:%d" % i)
    model_multi_views.append(model_view)
for i in range(3):
    model_multi_views[i].load_state_dict(
        torch.load("%s/model_view%d_%d.ckpt" % (arg.model_path, i + 1, arg.iter_from)))
    model_multi_views[i].eval()

valid_dataset = Pancreas(valid_volume_path, valid_label_path,
                         transform_views=[
                             transforms.Compose([
                                 CenterCrop((128, 128, 128)),
                                 Rotate((0, 0)),
                                 ToTensor()
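# Note on the snippet above: model.eval() switches dropout off, so the
# has_dropout=True networks behave deterministically at test time.  If
# Monte-Carlo dropout sampling is wanted during inference (as in the
# training-time uncertainty estimate), the dropout modules have to be put back
# into train mode explicitly.  A minimal sketch, assuming UNet3D uses standard
# nn.Dropout / nn.Dropout3d layers (not confirmed by this excerpt):
import torch.nn as nn


def enable_mc_dropout(model: nn.Module) -> None:
    """Keep the network in eval mode overall but re-activate dropout layers."""
    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()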
def train(gpu, batch_size, labeled_bs, seed1, seed2, iter_from, n_total_iter_from, n_epochs, lr,
          n_save_iter, data_path, model_dir_root_path, model_pre_trained_dir, note):
    """
    Training 3D U-Net
    :param gpu: gpu id
    :param batch_size: batch size
    :param labeled_bs: labeled batch size
    :param seed1: seed 1
    :param seed2: seed 2
    :param iter_from: checkpoint iteration to resume the models from
    :param n_total_iter_from: used for continuing training
    :param n_epochs: number of training epochs
    :param lr: learning rate
    :param n_save_iter: number of iterations between model checkpoints
    :param data_path: data path
    :param model_dir_root_path: the model directory root path to save to
    :param model_pre_trained_dir: path to pre-trained model
    :param note: free-form note stored with the run parameters
    :return:
    """

    """ setting """
    # gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # time
    now = time.localtime()
    now_format = time.strftime("%Y-%m-%d %H:%M:%S", now)  # time format
    date_now = now_format.split(' ')[0]
    time_now = now_format.split(' ')[1]
    # save model path
    save_path = os.path.join(model_dir_root_path, date_now, time_now)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # print setting
    print("----------------------------------setting-------------------------------------")
    print("lr:%f" % lr)
    if model_pre_trained_dir is None:
        print("pre-trained dir is None")
    else:
        print("pre-trained dir:%s" % model_pre_trained_dir)
    print("path of saving model:%s" % save_path)
    print("----------------------------------setting-------------------------------------")
    # save parameters to TXT.
    parameter_dict = {"gpu": gpu, "model_pre_trained_dir": model_pre_trained_dir,
                      "iter_from": iter_from, "lr": lr, "save_path": save_path, 'note': note}
    txt_name = 'parameter_log.txt'
    path = os.path.join(save_path, txt_name)
    with codecs.open(path, mode='a', encoding='utf-8') as file_txt:
        for key, value in parameter_dict.items():
            file_txt.write(str(key) + ':' + str(value) + '\n')
    # logging
    logging.basicConfig(filename=save_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(parameter_dict)
    # tensorboardX
    writer = SummaryWriter(log_dir=save_path)
    # label_dict
    label_list = [0, 1]

    """ data generator """
    # load all data path
    total_volume_path, total_label_path = util.gen_Pancreas_data_path(data_path)
    # 80 data -> 60 data for training
    # 80 data -> 20 data for validation
    train_volume_path, train_label_path, valid_volume_path, valid_label_path \
        = util.divide_data2train_valid(total_volume_path, total_label_path, 60, seed1)
    # dataset
    # training
    labeled_idxs = list(range(12))
    unlabeled_idxs = list(range(12, 60))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size, batch_size - labeled_bs)
    train_dataset = Pancreas(train_volume_path, train_label_path,
                             transform_views=[
                                 transforms.Compose([RandomNoise(), Rotate(0), ToTensor()]),
                                 transforms.Compose([RandomNoise(), Rotate(180), ToTensor()])
                             ])
    # validation
    # valid_dataset = Pancreas(valid_volume_path, valid_label_path,
    #                          transform=transforms.Compose([ToTensor]))

    # dataloader
    def worker_init_fn(worker_id):
        random.seed(seed1 + worker_id)

    train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler,
                                  num_workers=8, pin_memory=True, worker_init_fn=worker_init_fn)

    """ model, optimizer, loss """
    model_view1 = UNet3D(1, 2, has_dropout=True).to("cuda:0")
    model_view2 = UNet3D(1, 2, has_dropout=True).to("cuda:1")
    # model_view3 = UNet3D(1, 2).cuda()
    if iter_from != 0:
        model_view1.load_state_dict(
            torch.load("%s/model_view1_%d.ckpt" % (model_pre_trained_dir, iter_from)))
        model_view2.load_state_dict(
            torch.load("%s/model_view2_%d.ckpt" % (model_pre_trained_dir, iter_from)))
        # model_view3.load_state_dict(
        #     torch.load("%s/model_view3_%d.ckpt" % (model_pre_trained_dir, iter_from)))
    optimizer_view1 = torch.optim.Adam(model_view1.parameters(), lr=lr)
    optimizer_view2 = torch.optim.Adam(model_view2.parameters(), lr=lr)
    # optimizer_view3 = torch.optim.Adam(model_view3.parameters(), lr=lr)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = dice_loss

    """ training loop """
    # TODO: do label fusion (T^-1)
    n_total_iter = 0
    if n_total_iter_from != 0:
        n_total_iter = n_total_iter_from
    for epoch in range(n_epochs):
        for batch_index, sample_batch_views in enumerate(train_dataloader):
            # start_time
            start = time.time()
            # loading data
            for i, sample_batch in enumerate(sample_batch_views):
                # sample_batch_name = sample_batch['name']
                if i == 0:
                    device = "cuda:0"
                else:
                    device = "cuda:1"
                sample_batch['volume'] = sample_batch['volume'].to(device).float()
                sample_batch['label'] = sample_batch['label'].to(device).float()
            # labeled data (volume, label)
            labeled_volume_batch_view1 = sample_batch_views[0]['volume'][:labeled_bs]
            labeled_volume_batch_view2 = sample_batch_views[1]['volume'][:labeled_bs]
            label_batch_view1 = sample_batch_views[0]['label'][:labeled_bs]
            label_batch_view2 = sample_batch_views[1]['label'][:labeled_bs]
            # unlabeled data (volume)
            unlabeled_volume_batch_view1 = sample_batch_views[0]['volume'][labeled_bs:]
            unlabeled_volume_batch_view2 = sample_batch_views[1]['volume'][labeled_bs:]
            # put noise into unlabeled data
            noise1 = torch.clamp(torch.rand_like(unlabeled_volume_batch_view1) * 0.1, -0.2, 0.2).to('cuda:0')
            noise2 = torch.clamp(torch.rand_like(unlabeled_volume_batch_view2) * 0.1, -0.2, 0.2).to('cuda:1')
            # noise_unlabeled_volume_batch_view1 = unlabeled_volume_batch_view1 + noise1
            # noise_unlabeled_volume_batch_view2 = unlabeled_volume_batch_view2 + noise2

            # ------------------
            #  Train model
            # ------------------
            # zero the parameter gradients
            optimizer_view1.zero_grad()
            optimizer_view2.zero_grad()
            # run 3D U-Net model on labeled data with view1 & view2
            pred_labeled_view1 = model_view1(labeled_volume_batch_view1)
            pred_labeled_view2 = model_view2(labeled_volume_batch_view2)
            # run 3D U-Net model on unlabeled data with view1 & view2 (Bayesian, T stochastic passes)
            T = 8
            unlabeled_volume_batch_r_view1 = unlabeled_volume_batch_view1.repeat(2, 1, 1, 1, 1)
            unlabeled_volume_batch_r_view2 = unlabeled_volume_batch_view2.repeat(2, 1, 1, 1, 1)
            stride = unlabeled_volume_batch_r_view1.shape[0] // 2
            pred_unlabeled_view1 = torch.zeros([stride * T, 2, 128, 128, 128]).to('cuda:0')
            pred_unlabeled_view2 = torch.zeros([stride * T, 2, 128, 128, 128]).to('cuda:1')
            for i in range(T // 2):
                noise_unlabeled_volume_batch_view1 = unlabeled_volume_batch_r_view1 \
                    + torch.clamp(torch.rand_like(unlabeled_volume_batch_r_view1) * 0.1, -0.2, 0.2).to('cuda:0')
                noise_unlabeled_volume_batch_view2 = unlabeled_volume_batch_r_view2 \
                    + torch.clamp(torch.rand_like(unlabeled_volume_batch_r_view2) * 0.1, -0.2, 0.2).to('cuda:1')
                with torch.no_grad():
                    pred_unlabeled_view1[2 * stride * i:2 * stride * (i + 1)] = \
                        model_view1(noise_unlabeled_volume_batch_view1)
                    pred_unlabeled_view2[2 * stride * i:2 * stride * (i + 1)] = \
                        model_view2(noise_unlabeled_volume_batch_view2)
            pred_unlabeled_view1 = F.softmax(pred_unlabeled_view1, dim=1)
            pred_unlabeled_view1 = pred_unlabeled_view1.reshape(T, stride, 2, 128, 128, 128)
            pred_unlabeled_view1 = torch.mean(pred_unlabeled_view1, dim=0)  # (batch, 2, 128, 128, 128)
            uncertainty_view1 = -1.0 * torch.sum(pred_unlabeled_view1 * torch.log(pred_unlabeled_view1 + 1e-6),
                                                 dim=1, keepdim=True)  # (batch, 1, 128, 128, 128)
            pred_unlabeled_view2 = F.softmax(pred_unlabeled_view2, dim=1)
            pred_unlabeled_view2 = pred_unlabeled_view2.reshape(T, stride, 2, 128, 128, 128)
            pred_unlabeled_view2 = torch.mean(pred_unlabeled_view2, dim=0)  # (batch, 2, 128, 128, 128)
            uncertainty_view2 = -1.0 * torch.sum(pred_unlabeled_view2 * torch.log(pred_unlabeled_view2 + 1e-6),
                                                 dim=1, keepdim=True)  # (batch, 1, 128, 128, 128)
            # TODO: label fusion (the predictions must first be rotated back to a common view)
            # label_fusion_view1 =

            # Calculate loss
            label_batch_stand_view1 = util.standardized_seg(label_batch_view1, label_list, "cuda:0")
            label_batch_stand_view2 = util.standardized_seg(label_batch_view2, label_list, "cuda:1")
            loss_1_view1 = criterion1(pred_labeled_view1, label_batch_stand_view1)
            loss_1_view2 = criterion1(pred_labeled_view2, label_batch_stand_view2)
            label_batch_one_hot_view1 = util.onehot(label_batch_view1, label_list, "cuda:0")
            label_batch_one_hot_view2 = util.onehot(label_batch_view2, label_list, "cuda:1")
            pred_labeled_softmax_view1 = F.softmax(pred_labeled_view1, dim=1)
            pred_labeled_softmax_view2 = F.softmax(pred_labeled_view2, dim=1)
            loss_2_view1 = criterion2(pred_labeled_softmax_view1, label_batch_one_hot_view1)
            loss_2_view2 = criterion2(pred_labeled_softmax_view2, label_batch_one_hot_view2)
            loss_view1 = 0.5 * (loss_1_view1 + loss_2_view1)
            loss_view2 = 0.5 * (loss_1_view2 + loss_2_view2)
            # backward and optimize
            loss_view1.backward()
            optimizer_view1.step()
            loss_view2.backward()
            optimizer_view2.step()

            # ---------------------
            #  Print log
            # ---------------------
            # Determine approximate time left
            end = time.time()
            iter_left = (n_epochs - epoch) * (len(train_dataloader) - batch_index)
            time_left = datetime.timedelta(seconds=iter_left * (end - start))
            # print log
            logging.info("[Epoch: %4d/%d] [n_total_iter: %5d] [Total index: %2d/%d] "
                         "[loss view1: %f] [loss view2: %f] [ETA: %s]"
                         % (epoch, n_epochs, n_total_iter + 1, batch_index + 1, len(train_dataloader),
                            loss_view1.item(), loss_view2.item(), time_left))
            # tensorboardX log writer
            writer.add_scalar("loss_view1/loss", loss_view1.item(), global_step=n_total_iter)
            writer.add_scalar("loss_view1/loss_CrossEntropy", loss_1_view1.item(), global_step=n_total_iter)
            writer.add_scalar("loss_view1/loss_Dice", loss_2_view1.item(), global_step=n_total_iter)
            writer.add_scalar("loss_view2/loss", loss_view2.item(), global_step=n_total_iter)
            writer.add_scalar("loss_view2/loss_CrossEntropy", loss_1_view2.item(), global_step=n_total_iter)
            writer.add_scalar("loss_view2/loss_Dice", loss_2_view2.item(), global_step=n_total_iter)

            if n_total_iter % n_save_iter == 0:
                # Save model checkpoints
                torch.save(model_view1.state_dict(), "%s/model_view1_%d.ckpt" % (save_path, n_total_iter))
                torch.save(model_view2.state_dict(), "%s/model_view2_%d.ckpt" % (save_path, n_total_iter))
                logging.info("save model : %s/model_view1_%d.ckpt" % (save_path, n_total_iter))
                logging.info("save model : %s/model_view2_%d.ckpt" % (save_path, n_total_iter))
            n_total_iter += 1

    torch.save(model_view1.state_dict(), "%s/model_view1_%d.ckpt" % (save_path, n_total_iter))
    torch.save(model_view2.state_dict(), "%s/model_view2_%d.ckpt" % (save_path, n_total_iter))
    logging.info("save model : %s/model_view1_%d.ckpt" % (save_path, n_total_iter))
    logging.info("save model : %s/model_view2_%d.ckpt" % (save_path, n_total_iter))
    writer.close()
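# The core of the loop above is the Monte-Carlo dropout uncertainty estimate:
# each unlabeled batch is replicated, perturbed with clamped noise, pushed
# through the dropout-enabled network T = 8 times, the softmax outputs are
# averaged, and the per-voxel entropy u = -sum_c p_c * log(p_c + 1e-6) is used
# as the uncertainty.  A stripped-down sketch of that computation (hypothetical
# `model`; Gaussian noise is used here for illustration, whereas the script
# draws uniform noise):
import torch
import torch.nn.functional as F


def mc_dropout_uncertainty(model, volume, T=8):
    """Return (mean softmax, per-voxel entropy) over T noisy stochastic passes."""
    preds = []
    with torch.no_grad():
        for _ in range(T):
            noise = torch.clamp(torch.randn_like(volume) * 0.1, -0.2, 0.2)
            preds.append(F.softmax(model(volume + noise), dim=1))
    mean_prob = torch.stack(preds, dim=0).mean(dim=0)  # (bs, n_classes, D, H, W)
    entropy = -torch.sum(mean_prob * torch.log(mean_prob + 1e-6), dim=1, keepdim=True)
    return mean_prob, entropy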
def train(gpu, batch_size, seed1, seed2, iter_from, n_total_iter_from, n_epochs, lr, n_save_iter,
          data_path, model_dir_root_path, model_pre_trained_dir, note):
    """
    Training 3D U-Net
    :param gpu: gpu id
    :param batch_size: batch size
    :param seed1: seed 1
    :param seed2: seed 2
    :param iter_from: checkpoint iteration to resume the model from
    :param n_total_iter_from: used for continuing training
    :param n_epochs: number of training epochs
    :param lr: learning rate
    :param n_save_iter: number of iterations between model checkpoints
    :param data_path: data path
    :param model_dir_root_path: the model directory root path to save to
    :param model_pre_trained_dir: path to pre-trained model
    :param note: free-form note stored with the run parameters
    :return:
    """

    """ setting """
    # gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # time
    now = time.localtime()
    now_format = time.strftime("%Y-%m-%d %H:%M:%S", now)  # time format
    date_now = now_format.split(' ')[0]
    time_now = now_format.split(' ')[1]
    # save model path
    save_path = os.path.join(model_dir_root_path, date_now, time_now)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # print setting
    print("----------------------------------setting-------------------------------------")
    print("lr:%f" % lr)
    if model_pre_trained_dir is None:
        print("pre-trained dir is None")
    else:
        print("pre-trained dir:%s" % model_pre_trained_dir)
    print("path of saving model:%s" % save_path)
    print("----------------------------------setting-------------------------------------")
    # save parameters to TXT.
    parameter_dict = {
        "gpu": gpu,
        "batch size": batch_size,
        "model_pre_trained_dir": model_pre_trained_dir,
        "data path": data_path,
        "iter_from": iter_from,
        "lr": lr,
        "save_path": save_path,
        'note': note
    }
    txt_name = 'parameter_log.txt'
    path = os.path.join(save_path, txt_name)
    with codecs.open(path, mode='a', encoding='utf-8') as file_txt:
        for key, value in parameter_dict.items():
            file_txt.write(str(key) + ':' + str(value) + '\n')
    # logging
    logging.basicConfig(filename=save_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(parameter_dict)
    # tensorboardX
    writer = SummaryWriter(log_dir=save_path)
    # label_dict
    label_list = [0, 1]
    # patch size
    patch_size = (128, 128, 128)

    """ data generator """
    # load all data path
    total_volume_path, total_label_path = util.gen_Pancreas_data_path(data_path)
    # 82 data -> 62 data for training
    # 82 data -> 20 data for validation
    train_volume_path, train_label_path, valid_volume_path, valid_label_path \
        = util.divide_data2train_valid(total_volume_path, total_label_path, 62, seed1)
    train_volume_path = [i.replace('region', 'region2') for i in train_volume_path]
    train_label_path = [i.replace('region', 'region2') for i in train_label_path]
    # dataset
    # training
    # todo: validation
    train_volume_path.append(train_volume_path[-1])
    train_label_path.append(train_label_path[-1])
    total_dataset = Pancreas(train_volume_path, train_label_path,
                             transform=transforms.Compose([
                                 RandomCrop(patch_size),
                                 RandomNoise(),
                                 ToTensor()
                             ]))
    # validation
    valid_dataset = Pancreas(valid_volume_path, valid_label_path,
                             transform=transforms.Compose([ToTensor()]))

    # dataloader
    def worker_init_fn(worker_id):
        random.seed(seed1 + worker_id)

    total_dataloader = DataLoader(total_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=8, pin_memory=True, worker_init_fn=worker_init_fn)

    """ model, optimizer, loss """
    modelS = UNet3D(1, 2, has_dropout=True).cuda()
    if iter_from != 0:
        modelS.load_state_dict(
            torch.load("%s/modelS_%d.ckpt" % (model_pre_trained_dir, iter_from)))
    optimizer_S = torch.optim.Adam(modelS.parameters(), lr=lr)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = dice_loss

    """ training loop """
    n_total_iter = 0
    if n_total_iter_from != 0:
        n_total_iter = n_total_iter_from
    for epoch in range(n_epochs):
        for total_index, total_sample in enumerate(total_dataloader):
            # start_time
            start = time.time()
            # load the current batch
            total_name = total_sample['name']
            total_input = total_sample['volume'].to('cuda').float()
            total_seg = total_sample['label'].to('cuda').float()

            # ------------------
            #  Train model
            # ------------------
            # zero the parameter gradients
            optimizer_S.zero_grad()
            # run 3D U-Net model
            pred = modelS(total_input)
            # Calculate loss
            # todo : if wrong, check label_list
            total_seg_stand = util.standardized_seg(total_seg, label_list)
            loss_1 = criterion1(pred, total_seg_stand)
            total_seg_one_hot = util.onehot(total_seg, label_list)
            pred_softmax = F.softmax(pred, dim=1)
            loss_2 = criterion2(pred_softmax[:, 1, :, :, :], total_seg_one_hot[:, 1, :, :, :])
            loss = 0.5 * (loss_1 + loss_2)
            # backward and optimize
            loss.backward()
            optimizer_S.step()

            # ---------------------
            #  Print log
            # ---------------------
            # Determine approximate time left
            end = time.time()
            iter_left = (n_epochs - epoch) * (len(total_dataloader) - total_index)
            time_left = datetime.timedelta(seconds=iter_left * (end - start))
            # print log
            logging.info("[Epoch: %4d/%d] [n_total_iter: %5d] [Total index: %2d/%d] [loss: %f] [ETA: %s]"
                         % (epoch, n_epochs, n_total_iter + 1, total_index + 1,
                            len(total_dataloader), loss.item(), time_left))
            # tensorboardX log writer
            writer.add_scalar("loss/loss", loss.item(), global_step=n_total_iter)
            writer.add_scalar("loss/loss_CrossEntropy", loss_1.item(), global_step=n_total_iter)
            writer.add_scalar("loss/loss_Dice", loss_2.item(), global_step=n_total_iter)

            if n_total_iter % n_save_iter == 0:
                # Save model checkpoints
                torch.save(modelS.state_dict(), "%s/modelS_%d.ckpt" % (save_path, n_total_iter))
                logging.info("save model : %s/modelS_%d.ckpt" % (save_path, n_total_iter))
            n_total_iter += 1

    torch.save(modelS.state_dict(), "%s/modelS_%d.ckpt" % (save_path, n_total_iter))
    logging.info("save model : %s/modelS_%d.ckpt" % (save_path, n_total_iter))
    writer.close()
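# For reference, a hypothetical call of the fully supervised train() above.
# The argparse wrapper is not shown in this excerpt; all values below are
# illustrative placeholders, not taken from the repository:
if __name__ == "__main__":
    train(gpu="0",
          batch_size=2,
          seed1=1997,
          seed2=1997,
          iter_from=0,
          n_total_iter_from=0,
          n_epochs=1000,
          lr=1e-3,
          n_save_iter=500,
          data_path="../data/Pancreas",
          model_dir_root_path="../model/UNet3D",
          model_pre_trained_dir=None,
          note="fully supervised 3D U-Net baseline")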
total_volume_path, total_label_path = util.gen_Pancreas_data_path(arg.data_path)
total_volume_path = [i.replace('region', 'region2') for i in total_volume_path]
total_label_path = [i.replace('region', 'region2') for i in total_label_path]
train_volume_path, train_label_path, valid_volume_path, valid_label_path \
    = util.divide_data2train_valid(total_volume_path, total_label_path, 62, 1997)

option = arg.option
valid_options = ['baseline', 'upper bound', 'multi-views']
if option == valid_options[0]:
    modelS = UNet3D(1, 2).cuda()
    valid_dataset = Pancreas(valid_volume_path, valid_label_path,
                             transform=transforms.Compose([
                                 CenterCrop((128, 128, 128)),
                                 Rotate((1, 2)),
                                 ToTensor()
                             ]))
    # valid_dataset = Pancreas(valid_volume_path, valid_label_path,
    #                          transform=transforms.Compose([Rotate((0, 2)), ToTensor()]))
    # valid_dataset = Pancreas(valid_volume_path, valid_label_path,
    #                          transform=transforms.Compose([Rotate((1, 2)), ToTensor()]))
    # train_dataset = Pancreas(train_volume_path[0:12], train_label_path[0:12],
    #                          transform=transforms.Compose([Rotate((0, 2)), ToTensor()]))
    modelS.load_state_dict(
        torch.load("%s/modelS_%d.ckpt" % (arg.model_path, arg.iter_from)))
    save_dir = '../data/UNet3D_pred_baseline/' + arg.model_path.split('UNet3D/')[-1] + \
               '/' + str(arg.iter_from) + '/'
    print('\n\n baseline testing .........')
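# The baseline branch above only prepares a save_dir for predictions; the
# metric computation itself is not part of this excerpt.  If a Dice score over
# the saved binary masks is wanted, a minimal helper could look like this
# (hypothetical function, not present in the repository):
import torch


def dice_score(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float:
    """Dice coefficient between two binary volumes of identical shape."""
    pred = pred.float().flatten()
    target = target.float().flatten()
    intersection = (pred * target).sum()
    return float((2.0 * intersection + eps) / (pred.sum() + target.sum() + eps))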
def train(gpu, batch_size_stage1, batch_size_stage2, labeled_bs, view_num, n_iter_view_wise,
          seed1, seed2, iter_from, n_total_iter_from, n_epochs, lr_stage1, lr_stage2,
          n_save_iter, data_path, model_dir_root_path, model_pre_trained_dir, note):
    """
    Training 3D U-Net
    :param gpu: gpu id
    :param batch_size_stage1: batch size for stage 1
    :param batch_size_stage2: batch size for stage 2
    :param labeled_bs: labeled batch size
    :param view_num: number of views
    :param n_iter_view_wise: number of iterations for view-wise training
    :param seed1: seed 1
    :param seed2: seed 2
    :param iter_from: checkpoint iteration to resume the models from
    :param n_total_iter_from: used for continuing training
    :param n_epochs: number of training epochs
    :param lr_stage1: learning rate for view-wise training
    :param lr_stage2: learning rate for co-training
    :param n_save_iter: number of iterations between model checkpoints
    :param data_path: data path
    :param model_dir_root_path: the model directory root path to save to
    :param model_pre_trained_dir: path to pre-trained model
    :param note: free-form note stored with the run parameters
    :return:
    """

    """ setting """
    # gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # time
    now = time.localtime()
    now_format = time.strftime("%Y-%m-%d %H:%M:%S", now)  # time format
    date_now = now_format.split(' ')[0]
    time_now = now_format.split(' ')[1]
    # save model path
    save_path = os.path.join(model_dir_root_path, date_now, time_now)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # print setting
    print("----------------------------------setting-------------------------------------")
    print("lr for stage1 :%f" % lr_stage1)
    print("lr for stage2 :%f" % lr_stage2)
    if model_pre_trained_dir is None:
        print("pre-trained dir is None")
    else:
        print("pre-trained dir:%s" % model_pre_trained_dir)
    print("path of saving model:%s" % save_path)
    print("----------------------------------setting-------------------------------------")
    # save parameters to TXT.
    parameter_dict = {
        "gpu": gpu,
        "model_pre_trained_dir": model_pre_trained_dir,
        "iter_from": iter_from,
        "lr_stage1": lr_stage1,
        "lr_stage2": lr_stage2,
        "save_path": save_path,
        'note': note
    }
    txt_name = 'parameter_log.txt'
    path = os.path.join(save_path, txt_name)
    with codecs.open(path, mode='a', encoding='utf-8') as file_txt:
        for key, value in parameter_dict.items():
            file_txt.write(str(key) + ':' + str(value) + '\n')
    # logging
    logging.basicConfig(filename=save_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(parameter_dict)
    # tensorboardX
    writer = SummaryWriter(log_dir=save_path)
    # label_dict
    label_list = [0, 1]

    """ data generator """
    # load all data path
    total_volume_path, total_label_path = util.gen_Pancreas_data_path(data_path)
    # 82 data -> 62 data for training
    # 82 data -> 20 data for validation
    train_volume_path, train_label_path, valid_volume_path, valid_label_path \
        = util.divide_data2train_valid(total_volume_path, total_label_path, 62, seed1)
    # dataset
    # training
    labeled_idxs = list(range(12))
    unlabeled_idxs = list(range(12, 62))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size_stage2, batch_size_stage2 - labeled_bs)
    train_stage1_dataset = Pancreas(
        train_volume_path[0:12], train_label_path[0:12],
        transform_views=[
            transforms.Compose([RandomNoise(), ToTensor()]),
            transforms.Compose([RandomNoise(), Rotate((0, 2)), ToTensor()]),
            transforms.Compose([RandomNoise(), Rotate((1, 2)), ToTensor()])
        ])
    train_stage2_dataset = Pancreas(
        train_volume_path, train_label_path,
        transform_views=[
            transforms.Compose([RandomNoise(), ToTensor()]),
            transforms.Compose([RandomNoise(), Rotate((0, 2)), ToTensor()]),
            transforms.Compose([RandomNoise(), Rotate((1, 2)), ToTensor()])
        ])
    # validation
    # valid_dataset = Pancreas(valid_volume_path, valid_label_path,
    #                          transform=transforms.Compose([ToTensor]))

    # dataloader
    def worker_init_fn(worker_id):
        random.seed(seed1 + worker_id)

    train_stage1_dataloader = DataLoader(train_stage1_dataset, batch_size=batch_size_stage1,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         worker_init_fn=worker_init_fn)
    train_stage2_dataloader = DataLoader(train_stage2_dataset, batch_sampler=batch_sampler,
                                         num_workers=4, pin_memory=True,
                                         worker_init_fn=worker_init_fn)

    """ model, optimizer, loss """
    model_multi_views = []
    for i in range(view_num):
        model_view = UNet3D(1, 2, has_dropout=True).to("cuda:%d" % i)
        model_multi_views.append(model_view)
    if iter_from != 0:
        for i in range(view_num):
            model_multi_views[i].load_state_dict(
                torch.load("%s/model_view%d_%d.ckpt" % (model_pre_trained_dir, i + 1, iter_from)))
    optimizer_stage1_multi_views = []
    optimizer_stage2_multi_views = []
    for i in range(view_num):
        # view-wise training stage: lr=7e-3, momentum=0.9, weight decay=4e-5, 20k iterations (as in the original paper)
        optimizer_stage1_view = torch.optim.SGD(model_multi_views[i].parameters(),
                                                lr=lr_stage1, momentum=0.9, weight_decay=4e-5)
        # optimizer_stage1_view = torch.optim.Adam(model_multi_views[i].parameters(), lr=lr_stage1)
        optimizer_stage1_multi_views.append(optimizer_stage1_view)
        # co-training stage: constant lr=1e-3, 5k iterations
        optimizer_stage2_view = torch.optim.SGD(model_multi_views[i].parameters(), lr=lr_stage2)
        optimizer_stage2_multi_views.append(optimizer_stage2_view)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = dice_loss

    n_total_iter = 0
    if n_total_iter_from != 0:
        n_total_iter = n_total_iter_from
    for epoch in range(n_epochs):
        if n_total_iter < n_iter_view_wise:
            """ view-wise training """
            for batch_index, sample_batch_views in enumerate(train_stage1_dataloader):
                # start_time
                start = time.time()

                # ------------------
                #  Loading data
                # ------------------
                for i, sample_batch in enumerate(sample_batch_views):
                    sample_batch['volume'] = sample_batch['volume'].to("cuda:%d" % i).float()
                    sample_batch['label'] = sample_batch['label'].to("cuda:%d" % i).float()
                # labeled data(volume, label)
                # [view_num * (bs, 1, 128, 128, 128)]
                labeled_volume_batch_views = [sample_batch_views[i]['volume'] for i in range(view_num)]
                # [view_num * (bs, 1, 128, 128, 128)]
                label_batch_views = [sample_batch_views[i]['label'] for i in range(view_num)]

                # ------------------
                #  Train model
                # ------------------
                # zeros the parameter gradients
                for i in range(view_num):
                    optimizer_stage1_multi_views[i].zero_grad()
                # run 3D U-Net model on labeled data with multi-views
                # [view_num * (bs, 2, 128, 128, 128)]
                pred_labeled_views = [model_multi_views[i](labeled_volume_batch_views[i])
                                      for i in range(view_num)]
                pred_labeled_softmax_views = [F.softmax(pred_labeled_views[i], dim=1)
                                              for i in range(view_num)]
                # supervised loss
                loss_sup1_views = []
                loss_sup2_views = []
                loss_sup_total_views = []
                for i in range(view_num):
                    label_batch_stand_view = util.standardized_seg(label_batch_views[i], label_list,
                                                                   "cuda:%d" % i)
                    loss_sup1 = criterion1(pred_labeled_views[i], label_batch_stand_view)
                    loss_sup1_views.append(loss_sup1)
                    label_batch_one_hot_view = util.onehot(label_batch_views[i], label_list, "cuda:%d" % i)
                    loss_sup2 = criterion2(pred_labeled_softmax_views[i][:, 1, :, :, :],
                                           label_batch_one_hot_view[:, 1, :, :, :])
                    loss_sup2_views.append(loss_sup2)
                    loss_sup_total_views.append(0.5 * loss_sup1 + 0.5 * loss_sup2)
                for i in range(view_num):
                    loss_sup_total_views[i].backward()
                    optimizer_stage1_multi_views[i].step()

                # ---------------------
                #  Print log
                # ---------------------
                # logging, tensorboard
                # Determine approximate time left
                end = time.time()
                iter_left = (n_epochs - epoch) * (len(train_stage1_dataloader) - batch_index)
                time_left = datetime.timedelta(seconds=iter_left * (end - start))
                # print log
                logging.info("[Epoch: %4d/%d] [n_total_iter: %5d] [labeled volume index: %2d/%d] "
                             "[loss view1: %f] [loss view2: %f] [loss view3: %f] [ETA: %s]"
                             % (epoch, n_epochs, n_total_iter + 1, batch_index + 1,
                                len(train_stage1_dataloader), loss_sup_total_views[0].item(),
                                loss_sup_total_views[1].item(), loss_sup_total_views[2].item(), time_left))
                # tensorboardX log writer
                for i in range(view_num):
                    writer.add_scalar("loss_view%d/loss_sup_total" % i,
                                      loss_sup_total_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_sup_CrossEntropy" % i,
                                      loss_sup1_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_sup_Dice" % i,
                                      loss_sup2_views[i].item(), global_step=n_total_iter)

                # ---------------------
                #  Save model
                # ---------------------
                if n_total_iter % n_save_iter == 0:
                    # Save model checkpoints
                    for i in range(view_num):
                        torch.save(model_multi_views[i].state_dict(),
                                   "%s/model_view%d_%d.ckpt" % (save_path, i + 1, n_total_iter))
                        logging.info("save model : %s/model_view%d_%d.ckpt"
                                     % (save_path, i + 1, n_total_iter))
                n_total_iter += 1
        else:
            """ co-training """
            for batch_index, sample_batch_views in enumerate(train_stage2_dataloader):
                # start_time
                start = time.time()

                # ------------------
                #  Loading data
                # ------------------
                for i, sample_batch in enumerate(sample_batch_views):
                    sample_batch['volume'] = sample_batch['volume'].to("cuda:%d" % i).float()
                    sample_batch['label'] = sample_batch['label'].to("cuda:%d" % i).float()
                # labeled data(volume, label)
                # [view_num * (bs, 1, 128, 128, 128)]
                labeled_volume_batch_views = [sample_batch_views[i]['volume'][:labeled_bs]
                                              for i in range(view_num)]
                # [view_num * (bs, 1, 128, 128, 128)]
                label_batch_views = [sample_batch_views[i]['label'][:labeled_bs]
                                     for i in range(view_num)]
                # unlabeled data (volume)
                # [view_num * (bs, 1, 128, 128, 128)]
                unlabeled_volume_batch_views = [sample_batch_views[i]['volume'][labeled_bs:]
                                                for i in range(view_num)]

                # ------------------
                #  Train model
                # ------------------
                # zeros the parameter gradients
                for i in range(view_num):
                    optimizer_stage2_multi_views[i].zero_grad()
                # run 3D U-Net model on labeled data with multi-views
                # [view_num * (bs, 2, 128, 128, 128)]
                pred_labeled_views = [model_multi_views[i](labeled_volume_batch_views[i])
                                      for i in range(view_num)]
                pred_labeled_softmax_views = [F.softmax(pred_labeled_views[i], dim=1)
                                              for i in range(view_num)]
                # run 3D U-Net model on unlabeled data with multi-views
                pred_unlabeled_views = [model_multi_views[i](unlabeled_volume_batch_views[i])
                                        for i in range(view_num)]
                pred_unlabeled_softmax_views = [F.softmax(pred_unlabeled_views[i], dim=1)
                                                for i in range(view_num)]
                # run 3D U-Net Bayesian model on unlabeled data with multi-views
                T = 8
                unlabeled_volume_batch_repeat_views = [unlabeled_volume_batch_views[i].repeat(2, 1, 1, 1, 1)
                                                       for i in range(view_num)]
                stride = unlabeled_volume_batch_repeat_views[0].shape[0] // 2
                pred_unlabeled_bayes_views = [torch.zeros([stride * T, 2, 128, 128, 128]).to('cuda:%d' % i)
                                              for i in range(view_num)]
                for t in range(T // 2):
                    noise_unlabeled_volume_views = [
                        unlabeled_volume_batch_repeat_views[i]
                        + torch.clamp(torch.rand_like(unlabeled_volume_batch_repeat_views[i]) * 0.1,
                                      -0.2, 0.2).to('cuda:%d' % i)
                        for i in range(view_num)
                    ]
                    with torch.no_grad():
                        for i in range(view_num):
                            pred_unlabeled_bayes_views[i][2 * stride * t:2 * stride * (t + 1)] = \
                                model_multi_views[i](noise_unlabeled_volume_views[i])
                pred_unlabeled_bayes_views = [F.softmax(pred_unlabeled_bayes_views[i], dim=1)
                                              for i in range(view_num)]
                pred_unlabeled_bayes_views = [pred_unlabeled_bayes_views[i].reshape(T, stride, 2, 128, 128, 128)
                                              for i in range(view_num)]
                pred_unlabeled_bayes_views = [torch.mean(pred_unlabeled_bayes_views[i], dim=0)
                                              for i in range(view_num)]  # [view_num * (bs, 2, 128, 128, 128)]
                uncertainty_views = [
                    -1.0 * torch.sum(pred_unlabeled_bayes_views[i]
                                     * torch.log(pred_unlabeled_bayes_views[i] + 1e-6),
                                     dim=1, keepdim=True)
                    for i in range(view_num)
                ]  # [view_num * (bs, 1, 128, 128, 128)]
                # turn uncertainty and pred_unlabeled into the same view
                rotate_axes = [(0, 2), (1, 2)]
                for i in range(view_num):
                    if i == 0:
                        continue
                    else:
                        pred_unlabeled_softmax_views[i] = pred_unlabeled_softmax_views[i].rot90(
                            dims=rotate_axes[i - 1], k=-1)
                        pred_unlabeled_views[i] = pred_unlabeled_views[i].rot90(
                            dims=rotate_axes[i - 1], k=-1)
                        uncertainty_views[i] = uncertainty_views[i].rot90(
                            dims=rotate_axes[i - 1], k=-1)
                # label fusion
                pseudo_label_views = []
                for i in range(view_num):
                    top = torch.zeros_like(pred_unlabeled_softmax_views[i]).to('cuda:%d' % i)
                    down = torch.zeros_like(uncertainty_views[i]).to('cuda:%d' % i)
                    for j in range(view_num):
                        if i == j:
                            continue
                        else:
                            top += (pred_unlabeled_softmax_views[j]
                                    / uncertainty_views[j]).to('cuda:%d' % i)
                            down += 1 / uncertainty_views[j].to('cuda:%d' % i)
                    pseudo_label = top / down
                    pseudo_label = torch.argmax(pseudo_label, dim=1)
                    pseudo_label_views.append(pseudo_label)  # [view_num * (bs, 128, 128, 128)]
                # supervised loss
                loss_sup1_views = []
                loss_sup2_views = []
                loss_sup_total_views = []
                for i in range(view_num):
                    label_batch_stand_view = util.standardized_seg(label_batch_views[i], label_list,
                                                                   "cuda:%d" % i)
                    loss_sup1 = criterion1(pred_labeled_views[i], label_batch_stand_view)
                    loss_sup1_views.append(loss_sup1)
                    label_batch_one_hot_view = util.onehot(label_batch_views[i], label_list, "cuda:%d" % i)
                    loss_sup2 = criterion2(pred_labeled_softmax_views[i][:, 1, :, :, :],
                                           label_batch_one_hot_view[:, 1, :, :, :])
                    loss_sup2_views.append(loss_sup2)
                    loss_sup_total_views.append(0.5 * loss_sup1 + 0.5 * loss_sup2)
                # co-training loss
                loss_cot1_views = []
                loss_cot2_views = []
                loss_cot_total_views = []
                for i in range(view_num):
                    loss_cot1 = criterion1(pred_unlabeled_views[i], pseudo_label_views[i])
                    loss_cot1_views.append(loss_cot1)
                    pseudo_label_view = pseudo_label_views[i].unsqueeze(dim=1)  # (bs, 1, 128, 128, 128)
                    pseudo_label_one_hot_view = util.onehot(pseudo_label_view, label_list, "cuda:%d" % i)
                    loss_cot2 = criterion2(pred_unlabeled_softmax_views[i][:, 1, :, :, :],
                                           pseudo_label_one_hot_view[:, 1, :, :, :])
                    loss_cot2_views.append(loss_cot2)
                    loss_cot_total_views.append(0.5 * loss_cot1 + 0.5 * loss_cot2)
                # backwards and optimize
                loss_total_views = [loss_sup_total_views[i] + 0.2 * loss_cot_total_views[i]
                                    for i in range(view_num)]
                for i in range(view_num):
                    loss_total_views[i].backward()
                    optimizer_stage2_multi_views[i].step()

                # ---------------------
                #  Print log
                # ---------------------
                # logging, tensorboard
                # Determine approximate time left
                end = time.time()
                iter_left = (n_epochs - epoch) * (len(train_stage2_dataloader) - batch_index)
                time_left = datetime.timedelta(seconds=iter_left * (end - start))
                # print log
                logging.info("[Epoch: %4d/%d] [n_total_iter: %5d] [labeled volume index: %2d/%d] "
                             "[loss view1: %f] [loss view2: %f] [loss view3: %f][ETA: %s]"
                             % (epoch, n_epochs, n_total_iter + 1, batch_index + 1,
                                len(train_stage2_dataloader), loss_total_views[0].item(),
                                loss_total_views[1].item(), loss_total_views[2].item(), time_left))
                # tensorboardX log writer
                for i in range(view_num):
                    writer.add_scalar("loss_view%d/loss_SupCot_total" % i,
                                      loss_total_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_sup_total" % i,
                                      loss_sup_total_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_sup_CrossEntropy" % i,
                                      loss_sup1_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_sup_Dice" % i,
                                      loss_sup2_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_cot_total" % i,
                                      loss_cot_total_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_cot_CrossEntropy" % i,
                                      loss_cot1_views[i].item(), global_step=n_total_iter)
                    writer.add_scalar("loss_view%d/loss_cot_Dice" % i,
                                      loss_cot2_views[i].item(), global_step=n_total_iter)

                # ---------------------
                #  save model
                # ---------------------
                if n_total_iter % n_save_iter == 0:
                    # Save model checkpoints
                    for i in range(view_num):
                        torch.save(model_multi_views[i].state_dict(),
                                   "%s/model_view%d_%d.ckpt" % (save_path, i + 1, n_total_iter))
                        logging.info("save model : %s/model_view%d_%d.ckpt"
                                     % (save_path, i + 1, n_total_iter))
                n_total_iter += 1

    for i in range(view_num):
        torch.save(model_multi_views[i].state_dict(),
                   "%s/model_view%d_%d.ckpt" % (save_path, i + 1, n_total_iter))
        logging.info("save model : %s/model_view%d_%d.ckpt" % (save_path, i + 1, n_total_iter))
    writer.close()
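# The co-training branch above hinges on uncertainty-weighted label fusion:
# the pseudo label for view i is the argmax of
#     ( sum_{j != i} p_j / u_j ) / ( sum_{j != i} 1 / u_j ),
# i.e. the other views' softmax maps, rotated back into a common orientation
# and averaged with inverse-uncertainty weights.  A compact sketch of that
# fusion step on already-aligned tensors (single device, hypothetical helper;
# an eps term is added here for numerical safety, whereas the script divides
# by the raw entropy):
import torch


def fuse_pseudo_labels(softmax_views, uncertainty_views, target_view, eps=1e-6):
    """
    softmax_views:     list of (bs, C, D, H, W) softmax outputs, already aligned.
    uncertainty_views: list of (bs, 1, D, H, W) per-voxel entropy maps.
    Returns a (bs, D, H, W) hard pseudo label for `target_view`.
    """
    num = torch.zeros_like(softmax_views[target_view])
    den = torch.zeros_like(uncertainty_views[target_view])
    for j in range(len(softmax_views)):
        if j == target_view:
            continue
        weight = 1.0 / (uncertainty_views[j] + eps)
        num += softmax_views[j] * weight
        den += weight
    return torch.argmax(num / den, dim=1)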