def get_result(net, gpu=False):
    ids = get_ids(dir_img)
    val = get_imgs_and_masks(ids, dir_img, dir_mask, 1.0)
    val_dice = eval_net(net, val, gpu)
    print('Validation Dice Coeff: {}'.format(val_dice))
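# NOTE (sketch): eval_net comes from each repo's eval module and is not shown in
# this collection. A minimal hedged sketch of the generator-based variant called
# above -- signature eval_net(net, val, gpu), returning a mean Dice coefficient
# over (image, mask) numpy pairs; the 0.5 threshold is an assumption:
import torch

def eval_net_sketch(net, dataset, gpu=False):
    net.eval()
    tot, n = 0.0, 0
    with torch.no_grad():
        for img, true_mask in dataset:
            img = torch.from_numpy(img).unsqueeze(0).float()
            true_mask = torch.from_numpy(true_mask).unsqueeze(0).float()
            if gpu:
                img, true_mask = img.cuda(), true_mask.cuda()
            pred = (torch.sigmoid(net(img)) > 0.5).float()
            inter = (pred * true_mask).sum()
            tot += ((2 * inter) / (pred.sum() + true_mask.sum() + 1e-8)).item()
            n += 1
    net.train()
    return tot / max(n, 1)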
def batch_calc(args):
    from eval import eval_net

    # get net
    net = UNet(n_channels=3, n_classes=num_classes, bilinear=True)
    logging.info(f'Loading model {args.model}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device:
        device = torch.device(args.device)
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))
    logging.info('Model loaded !')

    val_dataset = CityscapesDataset(type='val', scale=args.scale)
    val_loader = DataLoader(val_dataset, batch_size=args.batchsize, shuffle=False,
                            num_workers=4, pin_memory=True, drop_last=True)
    miou, ious, hist = eval_net(net, val_loader, device, type='miou')
    logging.info(f'total mIoU value: {miou}\n'
                 f'per category\'s IoU value: \n{ious}\n'
                 f'hist saved as {args.histname}')
    np.savetxt(path.join('./test/hist/', args.histname), hist, delimiter=',')
    return miou
def train_net(net, train_data, val_data, optimizer, lr_scheduler, loss_fn, epochs,
              gpu=True, save_model=None, save_image=False, print_step=100):
    num_batch = len(train_data)
    for epoch in range(epochs):
        net.train()
        print("Epoch %d/%d" % (epoch + 1, epochs))
        epoch_loss = 0.0
        for step, data in enumerate(train_data):
            imgs, masks = data['img'], data['mask']
            if gpu:
                imgs = imgs.cuda()
                masks = masks.cuda()

            pred_masks = net(imgs)
            pred_masks_flat = pred_masks.view(-1)
            masks_flat = masks.view(-1)

            loss = loss_fn(pred_masks_flat, masks_flat)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % print_step == 0:
                print("Step %d/%d Loss: %.4f" %
                      (step + 1, num_batch, epoch_loss / (step + 1)))

        print("Epoch %d/%d Finished! Train loss: %.4f" %
              (epoch + 1, epochs, epoch_loss / (step + 1)))
        lr_scheduler.step(epoch_loss)

        val_dice, val_loss = eval_net(net, val_data, loss_fn, gpu, save_image,
                                      epoch=epoch + 1)
        print('Validation dice coeff is %f, loss is %.4f' % (val_dice, val_loss))

        if save_model:
            if not os.path.exists(save_model):
                os.mkdir(save_model)
            save_path = os.path.join(save_model,
                                     "epoch_%d_dice_%f.pth" % (epoch + 1, val_dice))
            torch.save(net.state_dict(), save_path)
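# NOTE (sketch): one way to wire this variant up. UNet, the dataset objects and
# eval_net are assumed from the surrounding repo; ReduceLROnPlateau matches the
# lr_scheduler.step(epoch_loss) call above.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

net = UNet(n_channels=3, n_classes=1)                                  # assumed
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)   # assumed
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)      # assumed
optimizer = optim.Adam(net.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
loss_fn = nn.BCEWithLogitsLoss()
train_net(net, train_loader, val_loader, optimizer, scheduler, loss_fn,
          epochs=10, gpu=torch.cuda.is_available(), save_model='checkpoints/')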
def train_net(net, args, epochs, batch_size, lr, device):
    train = BasicDataset(args, False)
    val = BasicDataset(args, True)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(val, batch_size=1, shuffle=False, pin_memory=True,
                            drop_last=True)
    n_val = len(val)
    n_train = len(train)

    writer = SummaryWriter('./records/tensorboard')
    global_step = 0
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                masks = batch['mask']
                name = batch['name'][0]
                imgs = imgs.to(device=device, dtype=torch.float32)
                masks = masks.to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, masks)
                pbar.set_postfix(**{'loss': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                writer.add_scalar('train', loss.item(), global_step)
                if global_step % (n_train // batch_size) == 0:
                    valid_score = eval_net(args, net, val_loader, device)
                    # scheduler.step(val_score)
                    writer.add_scalar('valid', valid_score, global_step)

        torch.save(net.state_dict(),
                   args.checkpoints_dir + 'naive_baseline_{}.pth'.format(epoch))
def validation(self, number_of_train_data, epoch):
    loss = self.epoch_loss / (number_of_train_data + 1)
    print("Epoch finished ! Loss: {}".format(loss))
    torch.save(
        self.net.state_dict(),
        str(self.save_weight_path.parent.joinpath(
            "epoch_weight/{:05d}.pth".format(epoch))),
    )
    val_loss = eval_net(
        self.net,
        self.val_loader,
        self.gpu,
        self.loss_flag,
        self.vis,
        self.img_view_val,
        self.gt_view_val,
        self.pred_view_val,
        self.criterion,
    )
    print("val_loss: {}".format(val_loss))
    try:
        if min(self.val_losses) > val_loss:
            torch.save(self.net.state_dict(), str(self.save_weight_path))
            self.bad = 0
            print("update bad")
            with self.save_weight_path.parent.joinpath("best.txt").open("w") as f:
                f.write("{}".format(epoch))
        else:
            self.bad += 1
            print("bad ++")
    except ValueError:
        # val_losses is empty on the first epoch; save unconditionally
        torch.save(self.net.state_dict(), str(self.save_weight_path))
    self.val_losses.append(val_loss)

    if self.need_vis:
        self.update_vis_plot(
            iteration=epoch,
            loss=loss,
            val=[loss, val_loss],
            window1=self.iter_plot,
            window2=self.epoch_plot,
            update_type="append",
        )
    print("bad = {}".format(self.bad))
    self.epoch_loss = 0
def validation(net, imgs, true_masks, masks_pred, writer, val_loader, n_val,
               global_step, device):
    # device is passed explicitly rather than relying on a module-level global
    val_score = eval_net(net, val_loader, device, n_val)
    if net.n_classes > 1:
        logging.info('Validation cross entropy: {}'.format(val_score))
        writer.add_scalar('Loss/test', val_score, global_step)
    else:
        logging.info('Validation Dice Coeff: {}'.format(val_score))
        writer.add_scalar('Dice/test', val_score, global_step)
    writer.add_images('images', imgs, global_step)
    if net.n_classes == 1:
        writer.add_images('masks/true', true_masks, global_step)
        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
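# NOTE (sketch): several snippets report a Dice coefficient via eval_net, and
# the meta-learning variant further down calls dice_coeff directly; neither
# helper is defined in this collection. A hedged sketch of the usual batched
# implementation:
import torch

def dice_coeff(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # mean Dice over a batch of binary masks shaped (N, 1, H, W)
    pred_flat = pred.contiguous().view(pred.size(0), -1)
    target_flat = target.contiguous().view(target.size(0), -1)
    inter = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    return ((2 * inter + eps) / (union + eps)).mean()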
def validation_only(net, device, batch_size=1, img_width=0, img_height=0,
                    img_scale=1.0, use_bw=False, standardize=False,
                    compute_statistics=False):
    load_statistics = not compute_statistics
    dataset = BasicDataset(dir_img_test, dir_mask_test, img_width, img_height,
                           img_scale, use_bw, standardize=standardize,
                           load_statistics=load_statistics, save_statistics=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    val_score = eval_net(net, val_loader, device)
    if net.n_classes > 1:
        logging.info('Validation cross entropy: {}'.format(val_score))
    else:
        logging.info('Validation Dice Coeff: {}'.format(val_score))
def test_net(net, device, batch_size=4, scale=512, threshold=0.5):
    # use the scale parameter instead of a hard-coded 512
    dataset = BasicDataset(dir_img, dir_mask, scale, False, 5)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=8, pin_memory=True)
    tm = TimeManager()
    val_score, precision, recall = eval_net(net, loader, device, threshold)
    if net.n_classes > 1:
        print('Validation cross entropy:', val_score)
    else:
        print('Validation Dice Coeff:', val_score)
        print('Validation Precision:', precision)
        print('Validation Recall:', recall)
    tm.show()
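# NOTE (sketch): TimeManager is not defined in these snippets. A minimal sketch
# consistent with the tm = TimeManager(); tm.show() usage above (construction
# starts a timer, show() prints the elapsed wall-clock time):
import time

class TimeManager:
    def __init__(self):
        self.start = time.time()

    def show(self):
        print('Elapsed: {:.2f}s'.format(time.time() - self.start))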
def train_net(net, epochs=5, batch_size=1, lr=0.1, val_percent=0.05,
              save_cp=True, gpu=False, img_scale=0.5):
    # dir_img = 'data/train/'
    # dir_mask = 'data/train_masks/'
    dir_img = 'E:/git/dataset/tgs-salt-identification-challenge/train/images/'
    dir_mask = 'E:/git/dataset/tgs-salt-identification-challenge/train/masks/'
    # dir_img = 'E:/git/dataset/tgs-salt-identification-challenge/train/my_images/'
    # dir_mask = 'E:/git/dataset/tgs-salt-identification-challenge/train/my_masks/'
    dir_checkpoint = 'checkpoints/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)
    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0
        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            # true_masks = np.array([i[1] for i in b])  # np.rot90(m)
            # 16-bit masks are scaled to [0, 1]
            true_masks = np.array([i[1].T / 65535 for i in b])  # np.rot90(m)
            # show_batch_image(true_masks)

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()
            # show_batch_image(imgs)

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
def train_net(
        net,
        writer,
        load,
        epochs=5,
        batch_size=1,
        lr=0.1,
        val_percent=0.1,
        save_cp=False,
        gpu=True,
):
    image_dir = 'train/images_cut/'
    mask_dir = 'train/masks_cut/'
    checkpoint_dir = 'checkpoints/'

    name_list = get_names(image_dir)
    split_list = train_val_split(name_list, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(split_list['train']),
               len(split_list['val']), str(save_cp), str(gpu)))

    N_train = len(split_list['train'])
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.005)

    print('Model loaded from {}'.format(args.load))
    model_dict = net.state_dict()
    pretrained_dict = torch.load(args.load)
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)

    train_params = []
    if args.fix:
        print("fixing parameters")
        for k, v in net.named_parameters():
            train_params.append(k)
            pref = k[:12]
            if pref == 'module.conv1' or pref == 'module.conv2':
                v.requires_grad = False
                train_params.remove(k)
        # rebuild the optimizer over the still-trainable parameters only
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                               lr=lr, weight_decay=0.005)

    criterion = mixloss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        train = get_train_pics(image_dir, mask_dir, split_list)
        epoch_loss = 0

        for i, samps in enumerate(batch(train, batch_size)):
            images = np.array([samp['image'] for samp in samps])
            masks = np.array([samp['mask'] for samp in samps])
            images = torch.from_numpy(images).type(torch.FloatTensor)
            masks = torch.from_numpy(masks).type(torch.FloatTensor)
            if gpu:
                images = images.cuda()
                masks = masks.cuda()
            true_masks = masks  # defined on both the CPU and GPU paths

            masks_pred = net(images)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_train_loss = epoch_loss / (i + 1)
        print('Epoch finished ! Loss: {}'.format(avg_train_loss))

        val = get_val_pics(image_dir, mask_dir, split_list)
        val_iou, val_ls = eval_net(net, val, gpu)
        print('Validation IoU: {} Loss: {}'.format(val_iou, val_ls))
        writer.add_scalar('train/loss', avg_train_loss, epoch)
        writer.add_scalar('val/loss', val_ls, epoch)
        writer.add_scalar('val/IoU', val_iou, epoch)

        torch.save(net.state_dict(), checkpoint_dir + 'CP{}.pth'.format(epoch + 1))
        print('Checkpoint {} saved !'.format(epoch + 1))
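# NOTE (sketch): mixloss() above is a custom criterion not included in this
# collection. A plausible hedged sketch, assuming it mixes BCE with a soft Dice
# term on the flattened probabilities (a common pairing when validation reports
# IoU); the class and weighting are assumptions:
import torch
import torch.nn as nn

class MixLossSketch(nn.Module):
    def __init__(self, dice_weight: float = 1.0, eps: float = 1e-8):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice_weight = dice_weight
        self.eps = eps

    def forward(self, probs_flat, targets_flat):
        bce = self.bce(probs_flat, targets_flat)
        inter = (probs_flat * targets_flat).sum()
        dice = (2 * inter + self.eps) / (probs_flat.sum() + targets_flat.sum() + self.eps)
        return bce + self.dice_weight * (1 - dice)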
def train_net(
        net,
        epochs=5,
        batch_size=2,
        lr=0.0001,
        save_cp=True,
        gpu=True,
        target_path='',
        checkpoint_path='/mnt/HDD1/Frederic/Segmentation/Seg_deepv3/checkpoints/'):
    # Set path to store checkpoints
    dir_checkpoint = checkpoint_path
    result_path = result_path_global

    # Print training details
    print('''
    Get started, training details:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, str(save_cp), str(gpu)))

    # loss function and optimizer
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    # criterion = nn.BCELoss()

    # training iteration logger
    logger = Logger(result_path + 'log.txt', title='ISIC2016_U_Net')
    logger.set_names(['Epochs', 'Avg_Training_Loss', 'Val_Dice_coefficient'])

    # load data
    val_sets = load_validation_data()
    start_epoch = args.start_epoch
    best_dice, best_epoch = 0, 0

    for epoch in range(start_epoch, start_epoch + epochs):
        net.train()
        # use epoch_loss to store the total loss over the whole epoch
        trainloader, datasize = load_train_data(args.batchsize)
        epoch_loss = 0

        # step-decay the learning rate at fixed epochs
        if epoch == 75 or epoch == 150 or epoch == 225:
            lr = lr * 0.1
            optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

        with tqdm(total=datasize / batch_size) as pbar:
            for ite, data in enumerate(trainloader[0]):
                imgs = data[0]
                true_masks = data[1]
                if gpu:
                    imgs = imgs.cuda()
                    true_masks = true_masks.cuda()

                masks_pred = net(imgs)
                true_masks = true_masks.squeeze(dim=1)
                # loss = DiceLoss(masks_pred, true_masks)
                loss = criterion(masks_pred, true_masks.long())
                epoch_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # progress bar for training
                pbar.set_description('Epoch:[%d|%d], loss: %.4f' %
                                     (epoch + 1, args.epochs + start_epoch, loss))
                pbar.update(1)

        avg_loss = epoch_loss / (ite + 1)
        print('Epoch finished ! Loss: {}'.format(avg_loss))

        # save the sample output every 50 epochs
        save_sample_mask = (epoch + 1) % 50 == 0
        val_dice = eval_net(net, val_sets, epoch, gpu, save_sample_mask, result_path)
        print('Validation Dice Coeff: {}'.format(val_dice))
        logger.append([epoch + 1, avg_loss, val_dice])

        # save best epoch and checkpoint
        if best_dice < val_dice:
            best_dice = val_dice
            best_epoch = epoch
            torch.save(net.state_dict(), dir_checkpoint + 'best_checkpoint.pth')
            print('best checkpoint is epoch {} with dice {}'.format(best_epoch, best_dice))

        # save regular checkpoint
        if save_cp and (epoch + 1) % 50 == 0:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            # print('Checkpoint {} saved !'.format(epoch + 1))

    # plot figures after training
    logger.close()
    logger.plot()
def train_net(net, epochs=100, batch_size=2, lr=0.02, val_percent=0.05,
              cp=True, gpu=False):
    dir_img = '/home/wdh/DataSets/hand-segmentation/GTEA_gaze_part/Resize/Images/'
    dir_mask = '/home/wdh/DataSets/hand-segmentation/GTEA_gaze_part/Resize/Masks_1/'
    dir_checkpoint = 'checkpoints/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)
    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.99))
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
        epoch_loss = 0

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        for i, b in enumerate(batch(train, batch_size)):
            X = np.array([i[0] for i in b])
            y = np.array([i[1] for i in b])

            # modern PyTorch: tensors carry autograd state, no Variable wrapper needed
            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)
            if gpu:
                X = X.cuda()
                y = y.cuda()

            y_pred = net(X)
            probs = torch.sigmoid(y_pred)
            probs_flat = probs.view(-1)
            y_flat = y.view(-1)

            loss = criterion(probs_flat, y_flat.float())
            epoch_loss += loss.item()
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        if cp:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
torch.save(net.state_dict(), save_file)
log.info('Saving pruned to {}...'.format(save_file))
save_txt = osp.join(save_dir, "pruned_channels.txt")
pruner.channel_save(save_txt)
log.info('Pruned channels to {}...'.format(save_txt))

del net, pruner
net = UNet(n_channels=3, n_classes=1, f_channels=save_txt)
log.info("Re-Built model using {}...".format(save_txt))
if args.gpu:
    net.cuda()
if args.load:
    net.load_state_dict(torch.load(save_file))
    log.info('Re-Loaded checkpoint from {}...'.format(save_file))

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0005)

# Use epochs or iterations for fine-tuning
save_file = osp.join(save_dir, "Finetuned.pth")
finetune(net, optimizer, criterion, iddataset['train'], log, save_file,
         args.iters, args.epochs, args.batch_size, args.gpu, args.scale)

val_dice = eval_net(net, val, len(iddataset['val']), args.gpu, args.batch_size)
log.info('Validation Dice Coeff: {}'.format(val_dice))
def train_net(net, epochs=5, batch_size=1, lr=0.1, val_percent=0.2,
              save_cp=True, gpu=False, img_scale=0.5):
    path = [['data/ori1/', 'data/gt1/'],
            ['data/original1/', 'data/ground_truth1/'],
            ['data/Original/', 'data/Ground_Truth/']]
    dir_img = path[0][0]
    dir_mask = path[0][1]
    dir_checkpoint = 'sdgcheck/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)
    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.7, weight_decay=0.005)
    # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.0005)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0
        x = 0
        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            true_masks = np.array([i[1] for i in b])
            # ori = np.transpose(imgs[0], axes=[1, 2, 0])
            # scipy.misc.imsave("ori/ori_" + str(x) + '.jpg', ori)
            # gt = np.stack((true_masks[0],) * 3, axis=-1)
            # gt = np.transpose(true_masks[0], axes=[1, 2, 0])
            # scipy.misc.imsave("gt/gt_" + str(x) + '.jpg', gt)

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)
            x += 1
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
def train_net(net, device, epochs=300, batch_size=1, lr=0.1, val_percent=0.5,
              save_cp=True, img_scale=0.366):
    # histogram_matching()
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    # n_val = int(len(dataset) * val_percent)
    # n_train = len(dataset) - n_val
    # train, val = random_split(dataset, [n_train, n_val])
    # Debug setup: train and validate on the same single image.
    n_val = 1
    n_train = 1
    # train = dataset
    val = list(dataset)[0:1]
    train = list(dataset)[0:1]
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=False,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch)
        lr = optimizer.param_groups[0]['lr']
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                # true_mask_pro = true_masks.squeeze().numpy().astype(np.uint8)
                # img_show = sitk.GetImageFromArray(true_mask_pro)
                # sitk.WriteImage(img_show, './data/pred/scale{}_mask.nii'.format(img_scale))
                # imgs_pnp = imgs.squeeze().cpu().detach().numpy()
                # imgs_show = sitk.GetImageFromArray(imgs_pnp)
                # sitk.WriteImage(imgs_show, f'./data/debug/epoch300_41Img/input_round1/{epoch+1}_{global_step+1}.nii.gz')
                # img_pro = (imgs * 32768).squeeze().numpy().astype(np.int16)
                # true_masks_pnp = true_masks.squeeze().cpu().detach().numpy()
                # true_masks_show = sitk.GetImageFromArray(true_masks_pnp)
                # sitk.WriteImage(true_masks_show, f'./data/debug/epoch300_41Img/gt_round1/{epoch+1}_{global_step+1}.nii.gz')
                # img_show = sitk.GetImageFromArray(img_pro)
                # sitk.WriteImage(img_show, './data/pred/scale{}_input.nii'.format(img_scale))
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                # if (epoch + 1) % 20 == 0:
                #     masks_pnp = masks_pred.squeeze().cpu().detach().numpy().astype(np.int16)
                #     masks_show = sitk.GetImageFromArray(masks_pnp)
                #     sitk.WriteImage(masks_show, f'./data/debug/scale0.5_1Img/model_out/{epoch+1}_{global_step+1}.nii')
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // batch_size) == 0:
                    # if global_step % 1 == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)
                    # writer.add_images('images', imgs, global_step)
                    # if net.n_classes == 1:
                    #     writer.add_images('masks/true', true_masks, global_step)
                    #     writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp and (epoch + 1) % 10 == 0:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + '72Img_80160_f4/' + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
def train_net(
        net,
        device,
        writer_detail,
        writer_main,
        random_seed,
        epochs=5,
        batch_size=1,
        lr=0.001,
        save_cp=True,
        val_i=4,
):
    train_dataset, val_dataset, train_list, val_list = make_dataset(
        root_dir, dataset_dir, val_i)
    n_val = len(val_dataset)
    n_train = len(train_dataset)
    # Note: worker_init_fn=np.random.seed(random_seed) seeds numpy once at loader
    # construction and passes None as the init fn; kept as in the original.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=True, drop_last=True,
                              worker_init_fn=np.random.seed(random_seed))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=True, drop_last=False,
                            worker_init_fn=np.random.seed(random_seed))
    global_step = 0

    logging.info(f'''Starting training:
        val_i:           {val_i}
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(params=net.parameters(), lr=lr)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if n_classes > 1 else 'max', patience=2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=np.power(0.1, 1 / epochs))
    if n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss(reduction='mean')

    best_score = 0
    best_net = net
    best_e = 0
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                contour_masks = batch['mask_contour']
                true_biMasks = (true_masks > 0).int()
                contour_masks = (contour_masks > 0).int()
                assert imgs.shape[1] == n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if n_classes == 1 else torch.long
                true_biMasks = true_biMasks.to(device=device, dtype=mask_type)
                contour_masks = contour_masks.to(device=device, dtype=mask_type)
                # combine_masks = torch.cat([true_biMasks, contour_masks], dim=1)

                masks_pred = net(imgs)
                # channel 0 predicts the object mask, channel 1 the contour;
                # the contour loss is up-weighted by a factor of 10
                loss_mask = criterion(masks_pred[:, 0:1], true_biMasks)
                loss_contour = criterion(masks_pred[:, 1:2], contour_masks)
                loss = loss_mask + 10 * loss_contour
                epoch_loss += loss.item()
                writer_main.add_scalar('val_%d_Loss/train' % val_i, loss.item(),
                                       global_step)
                pbar.set_postfix(loss_mask=loss_mask.item(),
                                 loss_contour=loss_contour.item())

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % ((n_train + n_val) // (2 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer_detail.add_histogram('val_%d_weights/' % val_i + tag,
                                                    value.data.cpu().numpy(),
                                                    global_step)
                        writer_detail.add_histogram('val_%d_grads/' % val_i + tag,
                                                    value.grad.data.cpu().numpy(),
                                                    global_step)
                    writer_detail.add_images('val_%d_images' % val_i, imgs[0:1],
                                             global_step)
                    if n_classes == 1:
                        writer_detail.add_images('val_%d_masks/true' % val_i,
                                                 true_biMasks[0:1], global_step)
                        writer_detail.add_images(
                            'val_%d_masks/pred' % val_i,
                            torch.sigmoid(masks_pred[0:1, 0:1]) > 0.5, global_step)
                        writer_detail.add_images('val_%d_masks_contour/true' % val_i,
                                                 contour_masks[0:1], global_step)
                        writer_detail.add_images(
                            'val_%d_masks_contour/pred' % val_i,
                            torch.sigmoid(masks_pred[0:1, 1:2]) > 0.5, global_step)

                    val_score_mask, val_score_contour = eval_net(net, val_loader,
                                                                 device, n_classes)
                    # scheduler.step(val_score)
                    writer_main.add_scalar('val_%d_learning_rate' % val_i,
                                           optimizer.param_groups[0]['lr'],
                                           global_step)
                    # if val_score >= best_score:
                    #     best_score = val_score
                    #     best_net = net
                    #     best_e = epoch
                    logging.info(
                        'val {} Validation score mask: {} Validation score contour: {}'
                        .format(val_i, val_score_mask, val_score_contour))
                    # if val_score < 10:
                    writer_main.add_scalar('val_%d_score_mask/test' % val_i,
                                           val_score_mask, global_step)
                    writer_main.add_scalar('val_%d_score_contour/test' % val_i,
                                           val_score_contour, global_step)

        scheduler.step()
        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(
                net.state_dict(),
                dir_checkpoint +
                f'val {val_i}_CP_epoch{epoch + 1}_scoreMask_%.4f_scoreContour_%.4f.pth'
                % (val_score_mask, val_score_contour))
            # logging.info(f'val {val_i} Checkpoint {epoch + 1} saved, best score: %.4f !' % best_score)

    # result_path = 'result_img'
    # if not os.path.exists(result_path):
    #     os.mkdir(result_path)
    # jaccard_vali = predict_netOnDataList(root_dir, best_net, val_list, device,
    #                                      result_path=result_path, vali=val_i,
    #                                      mask_threhold=0.5, vis_flag=True)
    # logging.info('val %d epoch %d jaccard score: %.6f' % (val_i, best_e + 1, jaccard_vali))
    # if save_cp:
    #     try:
    #         os.mkdir(dir_checkpoint)
    #         logging.info('Created checkpoint directory')
    #     except OSError:
    #         pass
    #     torch.save(best_net.state_dict(),
    #                dir_checkpoint + f'val {val_i}_CP_epoch{best_e + 1}_score_%.4f_jaccard_%.4f.pth' % (best_score, jaccard_vali))
    #     logging.info(f'val {val_i} Checkpoint {best_e + 1} saved, best score: %.4f !' % best_score)
    # return jaccard_vali
    return 0
def train_net(net, epochs=5, batch_size=1, lr=1e-3, val_percent=0.05,
              save_cp=True, gpu=False, img_scale=0.5):
    dir_img = '/home/xyj/data/spacenet/vegas/images_rgb_1300/'
    dir_mask = '/home/xyj/test/Pytorch-UNet/data/train_mask_point/'
    dir_checkpoint = 'checkpoints_point/'
    if not os.path.exists(dir_checkpoint):
        os.mkdir(dir_checkpoint)

    # ids = get_ids(dir_img)  # returns a generator of filenames under the train
    # folder, minus the last 4 characters (e.g. '.jpg')
    with open('train_list.txt', 'r') as f:
        lines = f.readlines()
        ids = (i.strip('\n')[:-4] for i in lines)
    # returns (id, i) tuples with id in ids and i in range(n), effectively
    # multiplying the training images n-fold; a generator of tuples
    ids = split_ids(ids)
    # split by validation percentage; a dict = {"train": [...], "val": [...]}
    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])
    # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-3)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.3)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0
        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32)
            # integer division binarizes the mask (pixel values of 200+ map to 1)
            true_masks = np.array([i[1] // 200 for i in b])

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
def train_net(net, device, epochs=5, batch_size=2, lr=0.0001, val_percent=0.2,
              save_cp=True, img_scale=1):
    # Init dataset and train/test split
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    # Call DataLoader
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)

    # Writer to tensorboard
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # Init optimizer and define lr_scheduler
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

    # In this version, we use BCEWithLogitsLoss
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']

                # Set-up device
                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                # Forward
                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (2 * batch_size)) == 0:
                    # if global_step % 100 == 0:
                    # Track weights and gradients
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(), global_step)

                    val_score, val_score_iou = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], global_step)

                    # Visualize val scores
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    logging.info('Validation IoU Coeff: {}'.format(val_score_iou))
                    writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_scalar('IoU/test', val_score_iou, global_step)
                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5, global_step)

        # Test with sample images
        test_folder = 'test/test_epoch_{}_new'.format(epoch)
        os.makedirs(test_folder)
        dirs = os.listdir('test/test_set')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for file in dirs:
            img = Image.open(os.path.join('test/test_set', file))
            mask = predict_img(net=net, full_img=img, device=device)
            result = mask_to_image(mask)
            result.save(os.path.join(test_folder, file))

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
def train_net(net, epochs=5, batch_size=1, lr=0.1, val_percent=0.05,
              save_cp=True, gpu=False, img_scale=0.5, dir_img=None,
              dir_mask=None, dir_checkpoint=None, channels=1, classes=1):
    ids = os.listdir(dir_img)
    if not os.path.exists(dir_checkpoint):
        os.makedirs(dir_checkpoint, mode=0o755)
    iddataset = split_train_val(ids, val_percent)

    print('Starting training:')
    print('Epochs: ' + str(epochs))
    print('Batch size: ' + str(batch_size))
    print('Learning rate: ' + str(lr))
    print('Training size: ' + str(len(iddataset['train'])))
    print('Validation size: ' + str(len(iddataset['val'])))
    print('Checkpoints: ' + str(save_cp))

    N_train = len(iddataset['train'])
    optimizer = optim.RMSprop(net.parameters(), lr=lr)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0
        # Run batches
        for i, b in enumerate(batch(train, batch_size)):
            # Grab data
            try:
                imgs = np.array([i[0] for i in b]).astype(np.float32)
                true_masks = np.array([i[1] for i in b])
            except Exception:
                print('probably dimension issues, wrong orientations, '
                      'or half-reconned images')
                continue  # skip the malformed batch instead of crashing below

            # Deal with dimension issues
            if channels == 1:
                imgs = np.expand_dims(imgs, 1)
            if classes > 1:
                true_masks = to_categorical(true_masks, num_classes=classes)

            # Play in torch's sandbox
            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            # Send to GPU
            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            # Predicted segmentations
            masks_pred = net(imgs)

            # Flatten
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)

            # Calculate loss between true/predicted
            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            # Batch loss
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Epoch loss
        print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

        val_dice = eval_net(net, val, epoch, dir_checkpoint, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint, 'CP{}.pth'.format(epoch + 1)))
            print('Checkpoint {} saved !'.format(epoch + 1))
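# NOTE (sketch): the multi-class path above relies on a Keras-style
# to_categorical that is not defined in these snippets; a hedged numpy sketch:
import numpy as np

def to_categorical(y: np.ndarray, num_classes: int) -> np.ndarray:
    # one-hot encode integer class masks along a trailing channel axis
    return np.eye(num_classes, dtype=np.float32)[y.astype(int)]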
def train_net(net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1,
              save_cp=True, img_scale=0.5):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    print("net class num: {}".format(net.n_classes))
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                # print(type(true_masks))
                # a = np.array(true_masks)
                # np.savetxt("a.csv", a[0], delimiter=",")
                # true_masks = torch.LongTensor(np.zeros((1, 250, 250)))
                # true_masks = true_masks.to(device=device, dtype=mask_type)
                print("predict mask: {}".format(masks_pred.size()))
                print("true mask: {}".format(true_masks.size()))

                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                print('success==========================')
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
def train_net(net, device, epochs=100, batch_size=1, lr=0.1, val_percent=0.2,
              save_cp=True, img_scale=1):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    gene_eval_data(val_loader, dir='./data/val/')

    writer = SummaryWriter(comment='LR_{}_BS_{}_SCALE_{}'.format(lr, batch_size, img_scale))
    global_step = 0

    logging.info('''Starting training:
        Epochs:          {}
        Batch size:      {}
        Learning rate:   {}
        Training size:   {}
        Validation size: {}
        Checkpoints:     {}
        Device:          {}
        Images scaling:  {}
    '''.format(epochs, batch_size, lr, n_train, n_val, save_cp, device.type, img_scale))

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', factor=0.5, patience=20)
    criterion = dice_loss
    # criterion = nn.BCELoss()

    last_loss = 9999
    last_val_score = 0
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        step = 0
        mybatch_size = 4  # accumulate gradients over 4 loader batches
        with tqdm(total=n_train, desc='Epoch {}/{}'.format(epoch + 1, epochs),
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    'Network has been defined with {} input channels, '.format(net.n_channels) + \
                    'but loaded images have {} channels. Please check that '.format(imgs.shape[1]) + \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)

                global_step += 1
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                step += 1
                if step % mybatch_size == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step = 0
                pbar.update(imgs.shape[0])

        # if global_step % (len(dataset) // (2 * batch_size)) == 0:
        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

        val_score = eval_net(net, val_loader, device)
        scheduler.step(val_score)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

        if net.n_classes > 1:
            logging.info('Validation cross entropy: {}'.format(val_score))
            writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Train Loss: {} Validation Dice Coeff: {}'.format(
                epoch_loss / n_train, val_score))
            writer.add_scalar('Dice/test', val_score, global_step)

        writer.add_images('images', imgs, global_step)
        if net.n_classes == 1:
            writer.add_images('masks/true', true_masks, global_step)
            writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.3, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if last_loss > epoch_loss or last_val_score < val_score:
                last_loss = min(last_loss, epoch_loss)
                last_val_score = max(last_val_score, val_score)
                # torch.save(net.state_dict(), ...)
                torch.save(net,
                           dir_checkpoint + 'CP_epoch{}Trainloss{}ValDice{}.pt'.format(
                               epoch + 1, epoch_loss / n_train, val_score))
                logging.info('Checkpoint {} saved !'.format(epoch + 1) +
                             ' CP_epoch{}Trainloss{}ValDice{}.pt'.format(
                                 epoch + 1, epoch_loss / n_train, val_score))

    writer.close()
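# NOTE (sketch): the mybatch_size counter above implements gradient
# accumulation -- optimizer.step() runs once every four loader batches,
# simulating a 4x effective batch size. A standalone sketch; dividing the loss
# by the accumulation factor (not done above) makes gradient magnitudes match a
# true large batch:
import torch

def train_with_accumulation(net, loader, optimizer, criterion, accum_steps=4):
    net.train()
    optimizer.zero_grad()
    for step, (imgs, targets) in enumerate(loader, start=1):
        loss = criterion(net(imgs), targets)
        (loss / accum_steps).backward()  # scale so grads match a big batch
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()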
def train_net(dir_checkpoint, n_classes, n_channels, device, epochs=30,
              save_cp=True, img_scale=1):
    global best_val_iou_score
    global best_test_iou_score

    net = PAN()
    net.to(device=device)
    batch_size = 4
    lr = 1e-5
    writer = SummaryWriter(
        comment=f'_{net.__class__.__name__}_LR_{lr}_BS_{batch_size}_categoryFirstEntropy_ACQUISITION')
    global_step = 0

    logging.basicConfig(
        filename="./logging_one32nd_category.txt",
        filemode='a',
        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.DEBUG)
    logging.info(f'''Starting training:
        Epochs:        {epochs}
        Batch size:    {batch_size}
        Learning rate: {lr}
        Checkpoints:   {save_cp}
        Device:        {device.type}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if n_classes > 1 else 'max', patience=2)
    if n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    # total 2689 imgs; within each phase, fetch 100 imgs into the training set
    num_phases = 25
    training_pool_ids_path = "data_one32nd_category.json"
    all_training_data = "data_all.json"

    for phase in range(num_phases):
        # Within a phase, save the best-epoch (highest test_iou) checkpoint and
        # log its test_iou to TensorBoard; also resume from the best checkpoint
        # of the previous phase.
        selected_images = get_pool_data(training_pool_ids_path)
        data_train = RestrictedDataset(dir_img, dir_mask, selected_images)
        data_test = BasicDataset(imgs_dir=dir_img_test, masks_dir=dir_mask_test,
                                 train=False, scale=img_scale)
        train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True,
                                  num_workers=4, pin_memory=True)
        test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False,
                                 num_workers=2, pin_memory=True, drop_last=True)

        right_previous_ckpt_dir = Path(dir_checkpoint + 'ckpt.pth')
        if right_previous_ckpt_dir.is_file():
            net.load_state_dict(torch.load(dir_checkpoint + 'ckpt.pth',
                                           map_location=device))

        for epoch in range(epochs):
            net.train()
            epoch_loss = 0
            n_train = len(data_train)
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}',
                      unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == n_channels, \
                        f'Network has been defined with {n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    mask_type = torch.float32 if n_classes == 1 else torch.long
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)  # returns BCHW = 8_1_256_256
                    _tem = net(imgs)
                    # print("IS DIFFERENT OR NOT: ", torch.sum(masks_pred - _tem))
                    true_masks = true_masks[:, :1, :, :]
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    # writer.add_scalar('Loss/train', loss.item(), global_step)
                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1

            # Compute Dice and IoU on the test set and log them to TensorBoard.
            test_score_dice, test_score_iou = eval_net(net, test_loader, n_classes, device)
            if test_score_iou > best_test_iou_score:
                best_test_iou_score = test_score_iou
                try:
                    os.mkdir(dir_checkpoint)
                    logging.info('Created checkpoint directory')
                except OSError:
                    pass
                torch.save(net.state_dict(),
                           dir_checkpoint + f'best_CP_epoch{epoch + 1}_one32th_.pth')
                logging.info(f'Checkpoint {epoch + 1} saved !')

            logging.info('Test Dice Coeff: {}'.format(test_score_dice))
            print('Test Dice Coeff: {}'.format(test_score_dice))
            writer.add_scalar(f'Phase_{phase}_Dice/test', test_score_dice, epoch)
            logging.info('Test IOU : {}'.format(test_score_iou))
            print('Test IOU : {}'.format(test_score_iou))
            writer.add_scalar(f'Phase_{phase}_IOU/test', test_score_iou, epoch)

        print(f"Phase_{phase}_best iou: ", best_test_iou_score)
        torch.save(net.state_dict(), dir_checkpoint + 'ckpt.pth')
        writer.add_scalar('Phase_IOU/test', best_test_iou_score, phase)

        # Fetch data for the next phase - update the pool of training images.
        update_training_pool_ids_2(net, training_pool_ids_path, all_training_data,
                                   device, acquisition_func="cfe")

    writer.close()
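# NOTE (sketch): get_pool_data and update_training_pool_ids_2 implement the
# active-learning pool above and are not shown in this collection. A hedged
# sketch of the reader side, assuming the JSON file stores a flat list of
# selected image ids:
import json

def get_pool_data(pool_ids_path: str) -> list:
    with open(pool_ids_path) as f:
        return json.load(f)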
def train_net(net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1,
              save_cp=True, img_scale=0.5):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)
    # writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    best_score = 0
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        start = time.time()
        with tqdm(total=n_train, desc=f'Epoch {epoch}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                # writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])

        cost_time = time.time() - start
        logging.info(f"{epoch} loss: {epoch_loss:.5f} time {cost_time:.3f}s")

        val_score = eval_net(net, val_loader, device, n_val)
        if net.n_classes > 1:
            logging.info('Validation cross entropy: {:.5f}'.format(val_score))
            # writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Validation Dice Coeff: {:.5f}'.format(val_score))
            # writer.add_scalar('Dice/test', val_score, global_step)
        # writer.add_images('images', imgs, global_step)
        # if net.n_classes == 1:
        #     writer.add_images('masks/true', true_masks, global_step)
        #     writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        if val_score > best_score:
            torch.save(net.state_dict(), log_dir + '/best.pth')
            best_score = val_score
            logging.info(f'best improved to {val_score:.5f}')
        torch.save(net.state_dict(), log_dir + "/latest.pth")
def train_net(net, device, figpath, epochs=5, batch_size=1, lr=0.001,
              val_percent=0.1, save_cp=True, img_scale=0.5, img_size=512,
              noise_fraction=0):
    dir_img = 'ISIC-2017_Training_Data/'
    dir_mask = 'ISIC-2017_Training_Part1_GroundTruth'
    dir_val_img = 'ISIC-2017_Training_Data_validation/'
    dir_val_mask = 'ISIC-2017_Training_Part1_GroundTruth_validation/'
    dir_cle_img = 'ISIC-2017_Training_Data_clean/'
    dir_cle_mask = 'ISIC-2017_Training_Part1_GroundTruth_validation_clean/'
    dir_checkpoint = 'checkpoints/'

    # Noisy masks live in a directory suffixed with the noise fraction.
    if noise_fraction != 0:
        dir_mask = dir_mask + '_' + str(noise_fraction) + '/'
    else:
        dir_mask = dir_mask + '/'
    print(dir_mask)

    train = BasicDataset(dir_img, dir_mask, img_scale, img_size)
    val = BasicDataset(dir_val_img, dir_val_mask, img_scale, img_size)
    cle = BasicDataset(dir_cle_img, dir_cle_mask, img_scale, img_size)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    cle_loader = DataLoader(cle, batch_size=5, shuffle=False,
                            num_workers=8, pin_memory=True)

    # A single fixed batch of clean data guides the example reweighting.
    batch = next(iter(cle_loader))
    clean_data = batch['image'].to(device=device, dtype=torch.float32)
    clean_labels = batch['mask'].to(device=device, dtype=torch.float32)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    net_losses = []
    acc_test = []
    acc_train = []
    dice_train = []
    dice_test = []
    loss_train = []
    num_batch = len(train_loader)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {len(train)}
        Validation size: {len(val)}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Images size:     {img_size}
        Noise fraction:  {noise_fraction}
    ''')

    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=0, momentum=0.99)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    # Per-pixel losses are needed for reweighting, hence reduction="none".
    if net.n_classes > 1:
        # criterion = nn.CrossEntropyLoss()
        criterion = nn.CrossEntropyLoss(reduction="none")
    else:
        # criterion = nn.BCEWithLogitsLoss()
        criterion = nn.BCEWithLogitsLoss(reduction="none")

    for epoch in range(epochs):
        net.train()
        tot = 0
        num_val = 0
        tot_val = 0
        epoch_loss = 0
        with tqdm(total=len(train), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                # Learning-to-reweight step: perturb per-pixel loss weights eps,
                # take a virtual update, and measure its effect on the clean batch.
                with higher.innerloop_ctx(net, optimizer) as (meta_net, meta_opt):
                    y_f_hat = meta_net(imgs)
                    loss = criterion(y_f_hat, true_masks[:, 0:1])
                    eps = torch.zeros(loss.size(), device=device).requires_grad_()
                    l_f_meta = torch.sum(loss * eps)
                    meta_opt.step(l_f_meta)

                    y_g_hat = meta_net(clean_data)
                    l_g_meta = torch.mean(criterion(y_g_hat, clean_labels[:, 0:1]))
                    grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True,
                                                   allow_unused=True)[0].detach()

                # Keep only weights whose increase lowers the clean loss, then normalise.
                w_tild = torch.clamp(-grad_eps, min=0)
                norm_c = torch.sum(w_tild)
                w = w_tild / norm_c if norm_c != 0 else w_tild

                masks_pred = net(imgs)
                pred = (torch.sigmoid(masks_pred) > 0.5).float()
                batch_dice = dice_coeff(pred, true_masks[:, 0:1]).item()
                tot += batch_dice
                dice_train.append(batch_dice)
                writer.add_scalar('Dice/train', batch_dice, global_step)
                if batch_dice <= 0.3:
                    writer.add_images('masks/true', true_masks[:, 0:1], global_step)
                    writer.add_images('masks/pred', pred, global_step)

                cost = criterion(masks_pred, true_masks[:, 0:1])
                loss = torch.sum(cost * w)
                epoch_loss += loss.item()
                net_losses.append(loss.item())
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (len(train) // (10 * batch_size)) == 0:
                    num_val += 1
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    dice_test.append(val_score)
                    tot_val += val_score
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        print('Step Validation Dice: ', val_score)
                        writer.add_scalar('Dice/test', val_score, global_step)

        print('Epoch: ', epoch)
        print('Epoch Loss: ', epoch_loss / num_batch)
        loss_train.append(epoch_loss / num_batch)
        print('Train EpochDice: ', tot / num_batch)
        acc_train.append(tot / num_batch)
        writer.add_scalar('EpochDice/train', tot / num_batch, epoch)
        print('Val EpochDice: ', tot_val / num_val)
        acc_test.append(tot_val / num_val)
        writer.add_scalar('EpochDice/test', tot_val / num_val, epoch)

        path = dir_checkpoint + figpath + '_' + str(epoch) + '_model.pth'
        # path = 'baseline/' + str(args.noise_fraction) + '/model.pth'
        torch.save(net.state_dict(), path)

    IPython.display.clear_output()
    fig, axes = plt.subplots(3, 2, figsize=(13, 5))
    ax1, ax2, ax3, ax4, ax5, ax6 = axes.ravel()
    ax1.plot(net_losses, label='iteration_losses')
    ax1.set_ylabel("Losses")
    ax1.set_xlabel("Iteration")
    ax1.legend()
    ax2.plot(loss_train, label='epoch_losses')
    ax2.set_ylabel('Losses')
    ax2.set_xlabel('Epoch')
    ax2.legend()
    ax3.plot(acc_train, label='dice_train_epoch')
    ax3.set_ylabel('EpochDice/train')
    ax3.set_xlabel('Epoch')
    ax3.legend()
    ax4.plot(acc_test, label='dice_test_epoch')
    ax4.set_ylabel('EpochDice/test')
    ax4.set_xlabel('Epoch')
    ax4.legend()
    ax5.plot(dice_train, label='dice_train_iteration')
    ax5.set_ylabel('IterationDice/train')
    ax5.set_xlabel('Iteration')
    ax5.legend()
    ax6.plot(dice_test, label='dice_test_iteration')
    ax6.set_ylabel('IterationDice/test')
    ax6.set_xlabel('Iteration')
    ax6.legend()
    plt.savefig(figpath + '.png')
    writer.close()
    return net
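
# Illustrative sketch, not from the sources above: the function above implements
# learning-to-reweight (Ren et al., 2018) via the `higher` library. The same
# step in isolation on a toy model; all names here (toy_model, noisy_x, ...)
# are invented for the example.
import torch
import torch.nn as nn
import higher  # pip install higher

toy_model = nn.Linear(4, 1)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_criterion = nn.MSELoss(reduction='none')  # per-example losses are required

noisy_x, noisy_y = torch.randn(8, 4), torch.randn(8, 1)  # possibly noisy batch
clean_x, clean_y = torch.randn(4, 4), torch.randn(4, 1)  # small trusted batch

with higher.innerloop_ctx(toy_model, toy_opt) as (fmodel, diffopt):
    # Virtual step: weight each example's loss by eps (initialised to zero).
    loss = toy_criterion(fmodel(noisy_x), noisy_y)
    eps = torch.zeros_like(loss, requires_grad=True)
    diffopt.step(torch.sum(loss * eps))
    # Clean loss after the virtual step, differentiated w.r.t. eps.
    clean_loss = toy_criterion(fmodel(clean_x), clean_y).mean()
    grad_eps = torch.autograd.grad(clean_loss, eps)[0]

# Examples whose upweighting would reduce the clean loss get positive weight.
w = torch.clamp(-grad_eps, min=0)
w = w / w.sum() if w.sum() > 0 else w

# Real update with the learned per-example weights.
toy_opt.zero_grad()
torch.sum(toy_criterion(toy_model(noisy_x), noisy_y) * w).backward()
toy_opt.step()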
def train_net(net, epochs=5, batch_size=1, lr=0.01, val_percent=0.05,
              save_cp=True, gpu=True):
    # Define directories
    dir_img = 'E:/Dataset/Dataset10k/images/training/'
    dir_mask = 'E:/Dataset/Dataset10k/annotations/training/'
    val_dir_img = 'E:/Dataset/Dataset10k/images/validation/'
    val_dir_mask = 'E:/Dataset/Dataset10k/annotations/validation/'
    dir_checkpoint = 'checkpoints/'

    # Get list of images and annotations
    train_images = os.listdir(dir_img)
    train_masks = os.listdir(dir_mask)
    train_size = len(train_images)
    val_images = os.listdir(val_dir_img)
    val_masks = os.listdir(val_dir_mask)
    val_size = len(val_images)

    val_imgs = np.array([read_image(val_dir_img + i) for i in val_images]).astype(np.float32)
    val_true_masks = np.array([read_masks(val_dir_mask + i) for i in val_masks])
    # Materialise the pairs: a bare zip() is exhausted after the first epoch's eval.
    val = list(zip(val_imgs, val_true_masks))

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, train_size, val_size, str(save_cp), str(gpu)))

    # Define optimizer and loss functions
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # nn.BCELoss expects probabilities, so the network's last layer must apply
    # a sigmoid (otherwise use nn.BCEWithLogitsLoss on the raw logits).
    criterion = nn.BCELoss()

    # Start training epochs
    num_batches = train_size // batch_size
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        epoch_loss = 0
        for i in range(num_batches):
            batch_imgs = train_images[i * batch_size:(i + 1) * batch_size]
            batch_masks = train_masks[i * batch_size:(i + 1) * batch_size]
            imgs = np.array([read_image(dir_img + name) for name in batch_imgs]).astype(np.float32)
            true_masks = np.array([read_masks(dir_mask + name) for name in batch_masks])
            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)
            # print(imgs.size(), true_masks.size())  # debug shapes

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = true_masks.view(-1)
            # print(masks_pred.size(), masks_probs_flat.size(), true_masks_flat.size())  # debug shapes

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / train_size, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / num_batches))

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))
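
# Illustrative note, not from the sources above: nn.BCELoss, as used in the
# function above, only works on probabilities in [0, 1]; on raw logits the
# fused nn.BCEWithLogitsLoss is the numerically safer equivalent.
import torch
import torch.nn as nn

logits = torch.randn(4, 1)                          # raw network outputs
targets = torch.empty(4, 1).uniform_(0, 1).round()  # binary ground truth

loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)
loss_b = nn.BCEWithLogitsLoss()(logits, targets)
assert torch.allclose(loss_a, loss_b, atol=1e-6)    # same value, stabler path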
def train_nets(gen_net, gen_optimizer, gen_scheduler, args):
    # if args.dataset == 'Aspect':
    #     train_dataset = AspectDataset(args.train_dir, args)
    #     val_dataset = AspectDataset(args.val_dir, args, validtion_flag=True)
    if args.dataset == 'IXI':
        train_dataset = IXIataset(args.train_dir, args)
        val_dataset = IXIataset(args.val_dir, args, validtion_flag=True)

    train_loader = DataLoader(train_dataset, batch_size=args.batchsize,
                              shuffle=True, num_workers=2, pin_memory=True)
    # shuffle is True just to get different images on tensorboard
    val_loader = DataLoader(val_dataset, batch_size=args.batchsize, shuffle=True,
                            num_workers=2, pin_memory=True, drop_last=True)

    # TODO: better name for checkpoints dir
    # dir_checkpoint is assumed to be defined at module level.
    writer = SummaryWriter(log_dir=dir_checkpoint + '/runs',
                           comment=f'LR_{args.lr}_BS_{args.batchsize}')

    logging.info(f'''Starting training:
        Epochs:        {args.epochs_n}
        Batch size:    {args.batchsize}
        Learning rate: {args.lr}
        Checkpoints:   {args.save_cp}
        Device:        {args.device}
    ''')

    gen_net.to(device=args.device)

    start_epoch = 0
    if args.load:
        checkpoint = torch.load(args.load, map_location=args.device)
        gen_net.load_state_dict(checkpoint['model_state_dict'])
        if args.load_scheduler_optimizer:
            start_epoch = int(checkpoint['epoch'])
            gen_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            gen_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            logging.info(f'Model, optimizer and scheduler loaded from {args.load}')
        else:
            logging.info(f'Model only loaded from {args.load}')

    criterion = netLoss(args)

    for epoch in range(start_epoch, args.epochs_n):
        gen_net.train()
        epoch_loss = 0
        progress_img = 0
        with tqdm(desc=f'Epoch {epoch + 1}/{args.epochs_n}', unit=' imgs') as pbar:
            # train
            for batch in train_loader:
                masked_Kspaces = batch['masked_Kspaces'].to(device=args.device, dtype=torch.float32)
                target_Kspace = batch['target_Kspace'].to(device=args.device, dtype=torch.float32)
                target_img = batch['target_img'].to(device=args.device, dtype=torch.float32)

                rec_img, rec_Kspace, F_rec_Kspace = gen_net(masked_Kspaces)

                FullLoss, ImL2, ImL1, KspaceL2 = criterion.calc(
                    rec_img, rec_Kspace, target_img, target_Kspace)
                epoch_loss += FullLoss.item()

                writer.add_scalar('train/FullLoss', FullLoss.item(), epoch)
                writer.add_scalar('train/ImL2', ImL2.item(), epoch)
                writer.add_scalar('train/ImL1', ImL1.item(), epoch)
                writer.add_scalar('train/KspaceL2', KspaceL2.item(), epoch)

                progress_img += 100 * target_Kspace.shape[0] / len(train_dataset)
                pbar.set_postfix(**{'FullLoss': FullLoss.item(),
                                    'ImL2': ImL2.item(),
                                    'ImL1': ImL1.item(),
                                    'KspaceL2': KspaceL2.item(),
                                    'Prctg of train set': progress_img})

                gen_optimizer.zero_grad()
                FullLoss.backward()
                # TODO: Do we need this clipping?
                nn.utils.clip_grad_value_(gen_net.parameters(), 0.1)
                gen_optimizer.step()

                pbar.update(target_Kspace.shape[0])  # current batch size

        writer.add_images('train/Fully_sampled_images', target_img, epoch)
        writer.add_images('train/rec_images', rec_img, epoch)
        writer.add_images('train/Kspace_rec_images', F_rec_Kspace, epoch)

        for tag, value in gen_net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), epoch)
            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), epoch)

        # validation:
        val_rec_img, val_full_img, val_F_rec_Kspace, val_FullLoss, val_ImL2, \
            val_ImL1, val_KspaceL2, val_PSNR = eval_net(gen_net, val_loader,
                                                        criterion, args.device)
        gen_scheduler.step(val_FullLoss)

        writer.add_images('validation/Fully_sampled_images', val_full_img, epoch)
        writer.add_images('validation/rec_images', val_rec_img, epoch)
        writer.add_images('validation/Kspace_rec_images', val_F_rec_Kspace, epoch)
        writer.add_scalar('learning_rate', gen_optimizer.param_groups[0]['lr'], epoch)

        logging.info('Validation full score: {}, ImL2: {}, ImL1: {}, KspaceL2: {}, PSNR: {}'
                     .format(val_FullLoss, val_ImL2, val_ImL1, val_KspaceL2, val_PSNR))
        writer.add_scalar('validation/FullLoss', val_FullLoss, epoch)
        writer.add_scalar('validation/ImL2', val_ImL2, epoch)
        writer.add_scalar('validation/ImL1', val_ImL1, epoch)
        writer.add_scalar('validation/KspaceL2', val_KspaceL2, epoch)
        writer.add_scalar('validation/PSNR', val_PSNR, epoch)

        if args.save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save({'epoch': epoch,
                        'model_state_dict': gen_net.state_dict(),
                        'optimizer_state_dict': gen_optimizer.state_dict(),
                        'scheduler_state_dict': gen_scheduler.state_dict()},
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
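
# Illustrative helper, not from the sources above: the checkpoints saved by
# train_nets bundle epoch, model, optimizer, and scheduler state in one dict,
# so resuming mirrors its `args.load` branch. A minimal sketch:
import torch

def load_gen_checkpoint(path, net, optimizer=None, scheduler=None, device='cpu'):
    """Restore the dict format saved by train_nets; returns the next epoch."""
    checkpoint = torch.load(path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return int(checkpoint['epoch']) + 1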
def train_net(net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1,
              save_cp=True, img_scale=0.5, data_augment=True):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=True)
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()  # 1 class
    best_score = 0.

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'
                assert true_masks.shape[1] == net.n_classes, \
                    f'Network has been defined with {net.n_classes} output classes, ' \
                    f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
                    'the masks are loaded correctly.'

                if data_augment:
                    # apply the same random transform to each image/mask pair
                    for i in range(len(imgs)):
                        imgs[i], true_masks[i] = my_segmentation_transforms(imgs[i], true_masks[i])

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    print(" ")
                    print('Validation Dice Coeff: {}'.format(val_score))
                    if best_score < val_score:
                        torch.save(net.state_dict(), 'BEST.pth')
                        logging.info('Best saved !')
                        best_score = val_score

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
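
# Illustrative sketch, not from the sources above: my_segmentation_transforms
# is not defined in this excerpt; the essential property is that image and
# mask receive the *same* random transform. A plausible version using
# torchvision's functional API (the exact augmentations are an assumption):
import random
import torchvision.transforms.functional as TF

def my_segmentation_transforms(image, mask):
    """Apply one shared random flip/rotation to an image tensor and its mask."""
    if random.random() > 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)
    if random.random() > 0.5:
        image, mask = TF.vflip(image), TF.vflip(mask)
    angle = random.choice([0, 90, 180, 270])
    if angle:
        image, mask = TF.rotate(image, angle), TF.rotate(mask, angle)
    return image, mask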
def train_net(net, epochs=5, batch_size=10, lr=0.1, val_percent=0.05,
              cp=True, gpu=False, mask_type="depth", half_scale=True):
    prefix = "/data/chc631/project/"
    dir_img = prefix + 'data/train/'
    # use depth map as target
    if mask_type == "depth":
        dir_mask = prefix + "data/train_masks_depth_map/"
    # use color map as target
    else:
        dir_mask = prefix + 'data/train_masks/'
    # options is assumed to be parsed at module level
    dir_checkpoint = "/data/chc631/project/data/checkpoints/" + options.dir
    if not os.path.exists(dir_checkpoint):
        os.makedirs(dir_checkpoint)

    ids = get_ids(dir_img)
    ids = split_ids(ids)
    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = scaleInvarLoss()

    for epoch in range(epochs):
        net.train()
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        epoch_loss = 0

        if half_scale:
            print("half_scale")
            train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, scale=0.5)
            val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, scale=0.5)
        else:
            train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, scale=1)
            val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, scale=1)

        num_batches = 0
        for i, b in enumerate(batch(train, batch_size)):
            X = torch.FloatTensor(np.array([item[0] for item in b]))
            y = torch.FloatTensor(np.array([item[1] for item in b]))
            # manually create a channel dimension for conv2d
            y = y.unsqueeze(0).transpose(0, 1)

            if gpu:
                X = X.cuda()
                y = y.cuda()

            y_pred = net(X)
            y_pred_flat = y_pred.view(-1)

            if half_scale:
                # 2x2 sum pooling so the target matches the half-scale prediction
                conv_mat = torch.ones(1, 1, 2, 2, device=y.device)
                y = F.conv2d(y, conv_mat, stride=2)
                y = torch.squeeze(y)

            y_flat = y.view(-1)
            loss = criterion(y_pred_flat, y_flat.float())
            epoch_loss += loss.item()
            num_batches = i + 1

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / num_batches))

        if cp:
            torch.save(net.state_dict(), dir_checkpoint + "/" + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))

        val_err = eval_net(net, val, gpu, half_scale)
        print('Validation Error: {}'.format(val_err))
        with open(dir_checkpoint + "/ValidationError.txt", 'a') as outfile:
            outfile.write(str(val_err) + '\n')
        with open(dir_checkpoint + "/TrainingError.txt", 'a') as outfile:
            outfile.write(str(epoch_loss / num_batches) + '\n')
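
# Illustrative sketch, not from the sources above: scaleInvarLoss is not shown
# in this excerpt; for depth regression the name usually refers to the
# scale-invariant log loss of Eigen et al. (NeurIPS 2014). A sketch under that
# assumption (class name matched to the call site, body is a guess):
import torch
import torch.nn as nn

class scaleInvarLoss(nn.Module):
    """Scale-invariant log depth loss: L = mean(d^2) - lam * mean(d)^2."""

    def __init__(self, lam=0.5, eps=1e-6):
        super().__init__()
        self.lam = lam  # lam=0 is plain L2 in log space; lam=1 is fully scale-invariant
        self.eps = eps  # avoids log(0) on empty depth pixels

    def forward(self, pred, target):
        # d_i = log y_i - log y*_i over all pixels
        d = torch.log(pred.clamp(min=self.eps)) - torch.log(target.clamp(min=self.eps))
        return (d ** 2).mean() - self.lam * d.mean() ** 2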
def train_net(net, device, epochs=5, batch_size=1, lr=0.001, val_percent=0.1,
              save_cp=True, img_scale=0.5):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
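
# Illustrative driver, not from the sources above: the kind of entry point this
# train_net variant is typically called from. UNet, dir_img, and dir_mask are
# assumed to be defined as in the surrounding scripts; the hyperparameter
# values are placeholders.
import logging
import torch

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    try:
        train_net(net=net, device=device, epochs=5, batch_size=1,
                  lr=0.001, img_scale=0.5, val_percent=0.1)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')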
def train_net(net, trainset, valset, device, epochs, batch_size, lr,
              weight_decay, log_save_path):
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True, drop_last=True)
    writer = SummaryWriter(log_dir=log_save_path)

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
    criterion = DiceBCELoss()
    best_DSC = 0.0

    for epoch in range(epochs):
        logging.info(f'Epoch {epoch + 1}')
        epoch_loss = 0
        epoch_dice = 0
        with tqdm(total=len(trainset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                net.train()
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                pred = (torch.sigmoid(masks_pred) > 0.5).float()
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                epoch_dice += dice_coeff(pred, true_masks).item()

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 5)
                optimizer.step()

                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])

        scheduler.step()
        logging.info('Training loss: {}'.format(epoch_loss / len(train_loader)))
        writer.add_scalar('Train/loss', epoch_loss / len(train_loader), epoch)
        logging.info('Training DSC: {}'.format(epoch_dice / len(train_loader)))
        writer.add_scalar('Train/dice', epoch_dice / len(train_loader), epoch)

        val_dice, val_loss = eval_net(net, val_loader, device, criterion)
        logging.info('Validation Loss: {}'.format(val_loss))
        writer.add_scalar('Val/loss', val_loss, epoch)
        logging.info('Validation DSC: {}'.format(val_dice))
        writer.add_scalar('Val/dice', val_dice, epoch)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
        # writer.add_images('images', imgs, epoch)
        writer.add_images('masks/true', true_masks, epoch)
        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, epoch)

    writer.close()
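
# Illustrative sketch, not from the sources above: DiceBCELoss is referenced in
# the function above but not defined in this excerpt; a common formulation adds
# a soft Dice term to BCE-with-logits. A sketch under that assumption:
import torch
import torch.nn as nn

class DiceBCELoss(nn.Module):
    """BCE-with-logits plus (1 - soft Dice) for binary segmentation."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # smoothing term avoids division by zero
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        bce = self.bce(logits, targets)
        probs = torch.sigmoid(logits).reshape(-1)
        flat_targets = targets.reshape(-1)
        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + flat_targets.sum() + self.smooth)
        return bce + (1. - dice)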