def train(train_sources, eval_source):
    """Train a UNet segmentator with domain-adversarial training.

    Training phases (by epoch index):
      * epoch < 50:  segmentator only (BCE on annotated source data).
      * epoch >= 50: additionally train the discriminator to predict the
        vendor (domain) from detached segmentator outputs.
      * epoch >= 100: additionally run an adversarial step that updates the
        segmentator encoder to fool the discriminator (negated CE loss,
        weighted by a ramped-up ``lmbd``), and lower/raise the optimizers'
        learning rates.

    NOTE(review): ``epochs = 3`` means the >=50 / >=100 phases never run as
    written — presumably a debug setting; confirm before a real run.

    Args:
        train_sources: list of annotated source-domain identifiers passed to
            ``DataReader``.
        eval_source: single unannotated target-domain identifier.

    Side effects: reads the data path from ``sys.argv[1]``, saves the trained
    segmentator weights under ``model/weights/``, and writes loss/dice plots.
    """
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Annotated (source) domain datasets; augmentation on the train split only.
    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y, dr.train.vendor, device, DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor, device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor, device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    # Unannotated (target) domain, used for the domain-adversarial steps and
    # for cross-domain evaluation.
    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()
    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y, dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y, dr_eval.test.vendor, device)

    # Mixed source+target loader feeding the discriminator/adversarial steps.
    dataset_da_train = MultiDomainDataset(
        dr.train.x + dr_eval.train.x,
        dr.train.y + dr_eval.train.y,
        dr.train.vendor + dr_eval.train.vendor,
        device,
        DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)
    sigmoid = nn.Sigmoid()
    selector = Selector()
    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001, weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001, weight_decay=0.01)
    # The adversarial optimizer updates only the encoder parameters.
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001, weight_decay=0.01)
    lmbd = 1 / 150  # adversarial loss weight; ramped up on every adversarial step

    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []

    epochs = 3
    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']
            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                # BUGFIX: the original mutated ``optimizer.defaults['lr']``,
                # which does NOT affect parameter groups created at
                # construction time — the schedule never took effect. The
                # learning rate must be set on each param_group instead.
                for group in s_optimizer.param_groups:
                    group['lr'] = 0.001
                for group in d_optimizer.param_groups:
                    group['lr'] = 0.0001
            if da_sample is None:
                # The mixed-domain loader is shorter-lived than the epoch
                # loop: restart it when exhausted.
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)

            if epoch < 50 or epoch >= 100:
                # Segmentation training step on annotated data.
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Discriminator training step; segmentator outputs are
                # detached so no gradient flows back into the segmentator.
                predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # Adversarial step: update the encoder to MAXIMIZE the
                # discriminator's loss (hence the negated, lmbd-weighted CE).
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor, da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1 / 150

        # Per-epoch evaluation with the full inference pipeline
        # (segmentator -> output selector -> sigmoid).
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model, s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev, inference_model, s_criterion, batch_size))
        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))
        segmentator.train()  # back to train mode (eval() above switched it)

    # Persist weights and training curves.
    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(pathlib.Path(__file__).parent.absolute(),
                              "model", "weights", "segmentator" + str(date_time) + ".pth")
    torch.save(segmentator.state_dict(), model_path)
    util.plot_data([(s_train_losses, 'train_losses'),
                    (s_dev_losses, 'dev_losses'),
                    (d_train_losses, 'discriminator_losses'),
                    (eval_domain_losses, 'eval_domain_losses')],
                   'losses.png')
    util.plot_dice([(train_dices, 'train_dice'),
                    (dev_dices, 'dev_dice'),
                    (eval_dices, 'eval_dice')],
                   'dices.png')

    # Final test-set evaluation on both domains.
    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()
    print('Dice on annotated: ', calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ', calculate_dice(inference_model, dataset_eval_test))
def train(args):
    """Train a segmentation model (UNet or torchvision DeepLabV3) with MSE loss.

    NOTE(review): this file defines ``train`` twice; this second definition
    shadows the earlier one at module level — confirm that is intended.

    Args:
        args: namespace with ``model`` ('unet' | 'deeplab'), ``label_type``
            ('msk' | 'flow'), ``batchsize``, ``num_workers``, ``lr``,
            ``epochs``.

    Raises:
        ValueError: if ``args.model`` is not a known model name.

    Side effects: creates ``result/<model>/{image,checkpoint}``, writes a loss
    plot, per-epoch preview images, periodic and best checkpoints, and the
    loss history as ``loss.npy``.
    """
    result_path = 'result/%s/' % args.model
    if not os.path.exists(result_path):
        os.makedirs(result_path)
        os.makedirs('%simage' % result_path)
        os.makedirs('%scheckpoint' % result_path)

    train_set = MyDataset('train', args.label_type, 512)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers)

    # Pick a specific GPU on multi-GPU hosts; otherwise the first one.
    # NOTE(review): no CPU fallback — this requires CUDA; confirm.
    device = 'cuda:6' if torch.cuda.device_count() > 1 else 'cuda:0'

    # Single-channel target for masks, two channels for flow.
    out_channels = 1 if args.label_type == 'msk' else 2
    print(out_channels)
    if args.model == 'unet':
        print('using unet as model!')
        model = UNet(out_channels=out_channels)
    elif args.model == 'deeplab':
        print('using deeplab as model!')
        model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
    else:
        # BUGFIX: the original printed 'no model!' and then crashed with
        # UnboundLocalError at model.to(device); fail fast with a clear error.
        raise ValueError('no model! unknown model: %r' % (args.model,))
    model = model.to(device)

    # Fixed preview sample used to render a progress image every epoch.
    img_show = train_set[0]['x']
    img_show = torch.tensor(img_show).to(device).float()[None, :]

    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_list = []
    loss_best = 10  # initial "best" sentinel; any real epoch mean should beat it

    for epo in tqdm(range(1, args.epochs + 1), ascii=True):
        epo_loss = []
        for idx, item in enumerate(train_loader):
            x = item['x'].to(device, dtype=torch.float)
            y = item['y'].to(device, dtype=torch.float)
            optimizer.zero_grad()
            if args.model == 'unet':
                pred = model(x)
            elif args.model == 'deeplab':
                # DeepLab returns a dict; keep channel 0 with a singleton dim
                # so shapes match the (N, 1, H, W) target.
                pred = model(x)['out'][:, 0][:, None]
            loss = criterion(pred, y)
            epo_loss.append(loss.data.item())
            loss.backward()
            optimizer.step()

        epo_loss_mean = np.array(epo_loss).mean()
        loss_list.append(epo_loss_mean)
        plot_loss(loss_list, '%simage/loss.png' % result_path)

        # Render the fixed preview sample with the current model.
        with torch.no_grad():
            if args.model == 'unet':
                pred = model(img_show.clone())
            elif args.model == 'deeplab':
                pred = model(img_show.clone())['out'][:, 0][:, None]
        if args.label_type == 'msk':
            x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
            y = pred[0, 0].cpu().detach().numpy()
        elif args.label_type == 'flow':
            x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
            y = pred[0].cpu().detach().numpy().transpose((1, 2, 0))
        plt.subplot(121)
        plt.imshow(x * 255)
        plt.subplot(122)
        # BUGFIX: in the 'msk' branch ``y`` is 2-D, so the original
        # ``y[:, :, 0]`` raised IndexError; index only when 3-D ('flow').
        plt.imshow(y if y.ndim == 2 else y[:, :, 0])
        plt.savefig('%simage/%d.png' % (result_path, epo))
        plt.clf()

        # Periodic checkpoint every 3 epochs, plus a running best-loss one.
        if epo % 3 == 0:
            torch.save(model, '%scheckpoint/%d.pt' % (result_path, epo))
        if epo_loss_mean < loss_best:
            loss_best = epo_loss_mean
            torch.save(model, '%scheckpoint/best.pt' % (result_path))
        np.save('%sloss.npy' % result_path, np.array(loss_list))