def predict(net, full_img, device, input_size, mask_way='warp'):
    """Run ``net.predict`` on one image and return a mask and ``theta``.

    :param net: model exposing ``predict(img, warp=...)`` returning
        ``(logits, rec_mask, theta)`` and an ``n_classes`` attribute.
    :param full_img: input image, passed to ``BasicDataset.preprocess_img``.
    :param device: torch device the input tensor is moved to.
    :param input_size: size argument forwarded to
        ``BasicDataset.preprocess_img``.
    :param mask_way: how the mask is obtained. ``'warp'`` uses the
        reconstructed mask returned by ``net.predict``. ``'segm'`` converts the
        logits with ``preds_to_masks``.
    :return: tuple ``(mask, theta)``.
    :raises NotImplementedError: if ``mask_way`` is neither ``'warp'`` nor
        ``'segm'``.
    """
    # Validate up front so an unsupported mode fails before preprocessing and
    # the forward pass are paid for.
    if mask_way not in ('warp', 'segm'):
        raise NotImplementedError(
            f"mask_way must be 'warp' or 'segm', got {mask_way!r}")
    use_warp = mask_way == 'warp'

    # Preprocess the input image into a batch of one.
    img = BasicDataset.preprocess_img(full_img, input_size)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    net.eval()
    with torch.no_grad():
        logits, rec_mask, theta = net.predict(img, warp=use_warp)

    if use_warp:
        # Scale the reconstructed mask by n_classes, then convert it to
        # integer labels stored as uint8.
        mask = rec_mask * net.n_classes
        mask = mask.type(torch.IntTensor).cpu().numpy().astype(np.uint8)
    else:
        mask = preds_to_masks(logits, net.n_classes)
    return mask, theta
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Predict a boolean mask for one image at its original resolution.

    ``full_img`` is assumed to be PIL-like, with ``.size == (width, height)``.
    TODO confirm.

    :param net: segmentation network with an ``n_classes`` attribute.
    :param full_img: input image, passed to ``BasicDataset.preprocess``.
    :param device: torch device used for inference.
    :param scale_factor: scale passed to ``BasicDataset.preprocess``.
    :param out_threshold: probability threshold for the returned mask.
    :return: boolean numpy array of thresholded probabilities.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        # A softmax over a single channel is identically 1.0, which would mark
        # every pixel as foreground. Single-class nets need a sigmoid.
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            # Exact (height, width). An int size would only set the shorter
            # edge, so the result could differ from the original image size.
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
def getDataSet(inputDir, labelDir, imgScale, valPercent):
    """Build a BasicDataset and randomly split it into ``(train, val)``.

    The validation subset receives ``int(len(dataset) * valPercent)`` samples
    (rounded down). Every remaining sample goes to the training subset.
    """
    fullSet = BasicDataset(inputDir, labelDir, imgScale)
    total = len(fullSet)
    valCount = int(total * valPercent)
    trainSubset, valSubset = random_split(fullSet, [total - valCount, valCount])
    return trainSubset, valSubset
def predict_img(net, full_img, device, input_size):
    """Return the masks ``net`` predicts for a single image.

    The image is prepared with ``BasicDataset.preprocess_img`` and batched. It
    is moved to ``device`` as float32 and run in eval mode without gradient
    tracking. The raw predictions are converted with ``preds_to_masks``.
    """
    net.eval()
    batch = BasicDataset.preprocess_img(full_img, input_size).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        predictions = net(batch)
        return preds_to_masks(predictions, net.n_classes)
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5,
                use_dense_crf=False):
    """Predict a boolean mask for one array image, optionally refined by CRF.

    ``full_img`` is indexed via ``.shape`` and is assumed to be laid out as
    ``(H, W[, C])``. TODO confirm.

    :param net: segmentation network with an ``n_classes`` attribute.
    :param full_img: input image array.
    :param device: torch device used for inference.
    :param scale_factor: scale used by the BasicDataset preprocessor.
    :param out_threshold: probability threshold for the returned mask.
    :param use_dense_crf: when true, ``dense_crf`` is applied to the
        probabilities before thresholding.
    :return: boolean numpy array of thresholded probabilities.
    """
    net.eval()
    # The dataset instance is only used for its preprocess() method here.
    ds = BasicDataset('', '', scale=scale_factor)
    img = torch.from_numpy(ds.preprocess(full_img))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        # Resize to the exact (H, W). An int size only sets the shorter edge,
        # and shape[1] is the width, so landscape images came out mis-sized.
        height, width = full_img.shape[0], full_img.shape[1]
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((height, width)),
            transforms.ToTensor()
        ])
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
    return full_mask > out_threshold
def prediction(net, imgs, device, scale=0.5, out_threshold=0.5):
    """Predict a boolean sigmoid mask for a single image ``imgs``.

    ``imgs`` is assumed to be PIL-like, with ``.size == (width, height)``.
    TODO confirm.

    :param net: segmentation network producing logits.
    :param imgs: the input image.
    :param device: torch device used for inference.
    :param scale: preprocessing scale (previously hard-coded to 0.5).
    :param out_threshold: probability threshold (previously hard-coded 0.5).
    :return: boolean numpy array at the original image size.
    """
    net.eval()
    # The dataset instance is only used for its preprocess() method here.
    ds = BasicDataset('data/training/img', 'data/training/full_mask',
                      scale=scale)
    img = torch.from_numpy(ds.preprocess(imgs))
    img = torch.unsqueeze(img, 0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)
        tf = transforms.Compose([
            transforms.ToPILImage(),
            # Exact (height, width). An int size only sets the shorter edge.
            transforms.Resize((imgs.size[1], imgs.size[0])),
            transforms.ToTensor()
        ])
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Return the per-pixel argmax class map for one image.

    ``out_threshold`` is accepted for signature compatibility but is unused.
    The network output is assumed to be ``(N, C, H, W)``, which makes the
    result the ``(H, W)`` index map of the first sample. TODO confirm.
    """
    net.eval()
    batch = torch.from_numpy(
        BasicDataset.preprocess(full_img, scale_factor, False)
    ).unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        class_map = logits.max(dim=1).indices
        return class_map.cpu().numpy()[0]
def plot_imgs_pred():
    """Run a segmentation net on one TIFF image and save diagnostic plots.

    Options come from ``get_plot_args()``. The function writes the first
    channel of the preprocessed input to ``<dir_output>original_img.png``. For
    every class index it also writes a binary map of the pixels whose argmax
    class equals that index to ``<dir_output>pred_cls_<index>.png``.

    :raises ValueError: if ``args.model_arch`` is neither ``'unet'`` nor
        ``'icnet'``.
    """
    args = get_plot_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img = tiff.imread(args.dir_img)
    img = BasicDataset.preprocess(img, scale=args.scale)
    img = torch.from_numpy(img).type(torch.FloatTensor)

    if args.model_arch == 'unet':
        net = UNet(n_channels=4, n_classes=4, bilinear=True)
    elif args.model_arch == 'icnet':
        net = ICNet(n_channels=4, n_classes=4, pretrained_base=False)
    else:
        # Otherwise `net` would be unbound and fail later with a NameError.
        raise ValueError(f"Unsupported model_arch: {args.model_arch!r}")
    net.load_state_dict(torch.load(args.checkpoint_net, map_location=device))
    net.to(device=device)
    net.eval()

    img = img.to(device=device, dtype=torch.float32)
    img = img.unsqueeze(0)
    with torch.no_grad():
        if args.model_arch == 'icnet':
            # ICNet's forward returns four tensors; only the first is plotted.
            mask_pred, pred_sub4, pred_sub8, pred_sub16 = net(img)
        else:
            mask_pred = net(img)

    # matplotlib cannot read CUDA tensors, so move to host memory first.
    plt.imshow(img[0][0].cpu())
    plt.colorbar()
    plt.savefig(args.dir_output + "original_img.png")
    plt.clf()

    for c in mask_pred:
        n_classes = c.size(0)
        c = torch.sigmoid(c)
        max_index = torch.max(c, 0).indices.cpu()
        for class_index in range(n_classes):
            # Binary map of the pixels predicted as this class.
            class_map = (max_index == class_index).float()
            plt.imshow(class_map)
            plt.colorbar()
            plt.savefig(args.dir_output + f"pred_cls_{class_index}.png")
            plt.clf()
def show_gt():
    """Show every (image, mask) pair of the dataset in OpenCV windows.

    Uses the module-level ``dir_img`` / ``dir_mask`` paths, a batch size of 1
    and an image scale of 1. Each sample stays on screen until
    ``cv2.waitKey()`` returns.
    """
    dataset = BasicDataset(dir_img, dir_mask, 1)
    loader = DataLoader(dataset, batch_size=1, shuffle=True,
                        num_workers=8, pin_memory=True)
    for sample in loader:
        # Take the only item of the batch and move its leading axis
        # (presumably channels) last, as OpenCV expects.
        image_hwc = sample['image'].numpy()[0].transpose(1, 2, 0)
        mask_hwc = sample['mask'].numpy()[0].transpose(1, 2, 0)
        cv2.imshow("img", image_hwc)
        cv2.imshow("mask", mask_hwc)
        cv2.waitKey()
def fun1():
    """Debug helper: print the shape of one training batch, then exit.

    Splits ``data/imgs/`` and ``data/masks/`` (scale 0.5) into train and
    validation sets with 10% held out. It prints the image tensor shape of the
    first training batch and then terminates the process with status 0.
    """
    import sys

    dir_img = 'data/imgs/'
    dir_mask = 'data/masks/'
    img_scale = 0.5
    val_percent = 0.1
    batch_size = 1

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    for batch in train_loader:
        print(batch['image'].shape)
        # The builtin `exit` is added by the site module for interactive
        # use only; sys.exit is the supported way to stop a program.
        sys.exit(0)
def predict_img(net, full_img, device, input_size):
    """Predict a segmentation mask and a projection mask for one image.

    Returns ``(mask, proj)``:

    * ``mask`` is ``preds_to_masks`` applied to the first network output.
    * ``proj`` is the second output multiplied by ``net.n_classes``,
      converted to an integer tensor, then to a uint8 numpy array.
    """
    batch = BasicDataset.preprocess_img(full_img, input_size).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    net.eval()
    with torch.no_grad():
        seg_out, proj_out = net(batch)

    mask = preds_to_masks(seg_out, net.n_classes)
    scaled = (proj_out * net.n_classes).type(torch.IntTensor)
    proj = scaled.cpu().numpy().astype(np.uint8)
    return mask, proj
def validation_only(net, device, batch_size=1, img_width=0, img_height=0,
                    img_scale=1.0, use_bw=False, standardize=False,
                    compute_statistics=False):
    """Evaluate ``net`` on the test split and log the resulting score.

    ``load_statistics`` is passed to BasicDataset as the negation of
    ``compute_statistics``, and ``save_statistics`` is always True. The score
    is logged as cross entropy for multi-class nets and as Dice otherwise.
    """
    test_set = BasicDataset(
        dir_img_test, dir_mask_test, img_width, img_height, img_scale, use_bw,
        standardize=standardize,
        load_statistics=not compute_statistics,
        save_statistics=True,
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=8, pin_memory=True, drop_last=True)
    score = eval_net(net, test_loader, device)
    template = ('Validation cross entropy: {}' if net.n_classes > 1
                else 'Validation Dice Coeff: {}')
    logging.info(template.format(score))
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Predict a per-pixel class map for one image and return a boolean mask.

    The network output is turned into probabilities: softmax when
    ``net.n_classes > 1``, sigmoid otherwise. These are reduced to the argmax
    class index per pixel and multiplied by 10. The result is resized with
    torchvision transforms and compared against ``out_threshold``.

    NOTE(review): torchvision's ToPILImage is documented to scale float
    tensors by 255 before casting to uint8 (confirm for the installed
    version). That would turn class index k into 2550*k, which exceeds the
    uint8 range for every k >= 1. Which non-zero classes end up above
    ``out_threshold`` is then not controlled. Confirm the intended output
    (e.g. "class != 0") and replace this encoding.
    NOTE(review): with ``n_classes == 1`` the argmax over a single channel is
    always 0.
    NOTE(review): ``Resize`` with an int sets only the shorter edge, so the
    result may not match the original (height, width).
    NOTE(review): the ``print`` below looks like leftover debugging output.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    # img_input = transforms.ToPILImage()(img.type(torch.float32)).convert('RGB')
    # img_input.save('input.jpg')
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)
        # Keep only the index of the most probable class per pixel.
        _, probs = torch.max(probs, dim=0)
        print(f'probs.max():{probs.max()}')
        # Add a leading channel dim, scale the class indices by 10, and cast
        # to float32 for the PIL round trip below.
        probs = probs.unsqueeze(0) * 10
        probs = probs.type(torch.float32)
        # tf = transforms.Compose(
        #     [
        #         transforms.ToPILImage(),
        #         transforms.Resize(full_img.size[1]),
        #         transforms.ToTensor()
        #     ]
        # )
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    """Return the per-pixel argmax class map for one image.

    Probabilities are softmax when ``net.n_classes > 1`` and sigmoid
    otherwise. The argmax is taken over axis 0 of the squeezed probability
    array. ``out_threshold`` is accepted for signature compatibility but is
    unused.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    full_mask = probs.cpu().numpy()
    return np.argmax(full_mask, axis=0)
def predict_img(net, full_img, device, scale_factor=0.5, out_threshold=0.5):
    """Return sigmoid probability maps for output channels 0 and 1.

    Each channel is passed through its own sigmoid and resized to
    ``(full_img.height, full_img.width)``. The two maps are stacked into one
    numpy array. ``out_threshold`` is accepted for signature compatibility
    but is not applied; the raw probabilities are returned.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

    tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((full_img.height, full_img.width)),
        transforms.ToTensor()
    ])
    channel_maps = []
    for channel in (0, 1):
        # Independent sigmoid per channel rather than a softmax across them.
        probs = torch.sigmoid(output[:, channel, :, :]).squeeze(0)
        probs = tf(probs.cpu())
        channel_maps.append(probs.squeeze().cpu().numpy())
    return np.array(channel_maps)
def predict_img(net, full_img1, full_img2, full_img3, full_img4, full_img5,
                device, scale_factor=0.267, out_threshold=0.6):
    """Run ``net`` on an input built from five images and return its output.

    The five images are combined by ``BasicDataset.preprocess``. The returned
    array is the squeezed network output, with no softmax or sigmoid applied.

    NOTE(review): ``out_threshold`` is unused, and callers receive raw scores
    rather than probabilities. Confirm this is what they expect.
    """
    net.eval()
    img = torch.from_numpy(
        BasicDataset.preprocess(full_img1, full_img2, full_img3, full_img4,
                                full_img5, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

    return output.squeeze().cpu().numpy()
def test_net(net, device, batch_size=4, scale=512, threshold=0.5):
    """Evaluate ``net`` on the ``dir_img``/``dir_mask`` data and print metrics.

    :param net: network to evaluate; ``n_classes`` selects the metric label.
    :param device: torch device used by ``eval_net``.
    :param batch_size: evaluation batch size.
    :param scale: third positional BasicDataset argument. It was previously
        ignored in favour of a hard-coded 512, which is still the default.
    :param threshold: forwarded to ``eval_net``.
    """
    dataset = BasicDataset(dir_img, dir_mask, scale, False, 5)
    # Evaluation does not need shuffling, and a fixed order keeps repeated
    # runs reproducible.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=8, pin_memory=True)
    tm = TimeManager()
    val_score, precision, recall = eval_net(net, loader, device, threshold)
    if net.n_classes > 1:
        print('Validation cross entropy:', val_score)
    else:
        print('Validation Dice Coeff:', val_score)
    print('Validation Precision:', precision)
    print('Validation Recall:', recall)
    tm.show()
def inference_one(net, image, device):
    """Predict boolean mask(s) for one PIL-style image using ``cfg`` settings.

    Returns a single boolean array when ``cfg.n_classes == 1``. Otherwise it
    returns a list with one boolean array per class channel. Every map is
    resized to ``(image.size[1], image.size[0])`` and compared against
    ``cfg.out_threshold``.
    """
    net.eval()
    batch = torch.from_numpy(BasicDataset.preprocess(image, cfg.scale))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(batch)
        if cfg.deepsupervision:
            # With deep supervision only the last output is used.
            output = output[-1]

        if cfg.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)  # C x H x W

        resize_back = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image.size[1], image.size[0])),
            transforms.ToTensor()
        ])

        def to_mask(channel_probs):
            resized = resize_back(channel_probs.cpu())
            return resized.squeeze().cpu().numpy() > cfg.out_threshold

        if cfg.n_classes == 1:
            return to_mask(probs)
        return [to_mask(prob) for prob in probs]
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Predict centerline and point masks for one PIL-style image.

    The network returns two outputs, ``(centerlines, points)``. Each becomes
    probabilities (softmax when ``net.n_classes > 1``, sigmoid otherwise). It
    is then resized to the original ``(height, width)`` and compared against
    ``out_threshold``.

    :return: tuple ``(centerline_mask, point_mask)`` of boolean arrays.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        centerlines, points = net(img)
        if net.n_classes > 1:
            probs1 = F.softmax(centerlines, dim=1)
            probs2 = F.softmax(points, dim=1)
        else:
            probs1 = torch.sigmoid(centerlines)
            probs2 = torch.sigmoid(points)

    tf = transforms.Compose([
        transforms.ToPILImage(),
        # Exact (height, width). An int size only sets the shorter edge.
        transforms.Resize((full_img.size[1], full_img.size[0])),
        transforms.ToTensor()
    ])

    def to_mask(probs):
        """Resize one probability map to image size and threshold it."""
        full = tf(probs.squeeze(0).cpu()).squeeze().cpu().numpy()
        return full > out_threshold

    return to_mask(probs1), to_mask(probs2)
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    """Return the per-pixel argmax class map at the original image size.

    Probabilities are softmax when ``net.n_classes > 1`` and sigmoid
    otherwise. They are resized to ``(full_img.size[1], full_img.size[0])``
    before the argmax over axis 0. ``out_threshold`` is accepted for signature
    compatibility but is unused.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    tf = transforms.Compose([
        transforms.ToPILImage(),
        # Exact (height, width). An int size only sets the shorter edge, so
        # the other edge could be off by a few pixels after rounding.
        transforms.Resize((full_img.size[1], full_img.size[0])),
        transforms.ToTensor()
    ])
    probs = tf(probs.cpu())
    full_mask = probs.squeeze().cpu().numpy()
    return np.argmax(full_mask, axis=0)
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5,
                show_plot=False):
    """Return per-channel sigmoid probabilities for one image.

    A sigmoid is applied to every channel regardless of ``net.n_classes``,
    and no thresholding is done (``out_threshold`` is unused).

    :param show_plot: when true, display the argmax over channels 2 onward in
        a blocking matplotlib window. Previously this plot was shown on every
        call.
    :return: the squeezed probability array.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    full_mask = probs.squeeze().cpu().numpy()

    if show_plot:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.imshow(full_mask[2:, :, :].argmax(0))
        plt.colorbar()
        plt.show()
    return full_mask
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Predict per-class boolean masks for one PIL-style image.

    :param out_threshold: either a scalar threshold applied to every channel,
        or a sequence with one threshold per class channel. Channels beyond
        the end of the sequence stay False.
    :return: boolean numpy array with the same shape as the squeezed,
        resized probabilities.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    tf = transforms.Compose([
        transforms.ToPILImage(),
        # Exact (height, width). An int size only sets the shorter edge.
        transforms.Resize((full_img.size[1], full_img.size[0])),
        transforms.ToTensor()
    ])
    probs = tf(probs.cpu())
    full_mask = probs.squeeze().cpu().numpy()

    if np.ndim(out_threshold) == 0:
        # A scalar (including the default) applies to every channel; the
        # per-channel loop below would raise TypeError when iterating it.
        return full_mask > out_threshold

    # The `np.bool` alias was removed in NumPy 1.24; use the builtin bool.
    result = np.zeros(full_mask.shape, dtype=bool)
    for i, thres in enumerate(out_threshold):
        result[i] = full_mask[i] > thres
    return result
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5,
                outf=None):
    """Predict a boolean mask for one image, optionally saving the net input.

    The image is preprocessed with a 128x128 ``Resize``. The mask is returned
    at the network's output resolution; there is no resize back to the image
    size.

    :param outf: optional output path. When given, the preprocessed input
        batch is saved as ``./data/predicted/train128_<basename of outf>``.
        When None (the default), nothing is saved.
    :return: boolean numpy array of thresholded probabilities.
    """
    import os

    net.eval()
    transform = transforms.Compose([transforms.Resize((128, 128))])
    img = torch.from_numpy(
        BasicDataset.preprocess(full_img, scale_factor, transform))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])
    probs = tf(probs.cpu())
    full_mask = probs.squeeze().cpu().numpy()

    if outf is not None:
        # outf.split('/')[3] only gave the file name when the path had exactly
        # four '/'-separated parts. basename works for any depth.
        outtrain_fn = "./data/predicted/train128_" + os.path.basename(outf)
        save_image(img, outtrain_fn)
    return full_mask > out_threshold
def predict_img(net, full_img, device, file_name, scale_factor=1,
                out_threshold=0.5,
                save_dir='D:/users/otis/MedicalImage_Project02_Segmentation/MedicalImage_Project02_Segmentation/private_data_10/private_data_10/Results'):
    """Predict a boolean mask and save the probabilities as a .mat file.

    The network output is read from ``net(img)['out']``. Probabilities
    (softmax when ``net.num_classes > 1``, sigmoid otherwise) are saved under
    the key ``'array'`` with singleton dimensions dropped.

    :param file_name: source file name; its extension is replaced by
        ``_prob.mat`` to build the output name.
    :param save_dir: output directory prefix. The default is the previously
        hard-coded path.
    :return: boolean mask resized to ``(full_img.size[1], full_img.size[0])``.
    """
    import os

    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)['out']
        if net.num_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    # Dropping singleton dims gives the same (300, 300) array as before for
    # single-channel 300x300 outputs, without failing on other sizes.
    mat_prob = probs.cpu().numpy().squeeze()
    stem = os.path.splitext(file_name)[0]
    # NOTE(review): save_dir and the stem are joined without a separator;
    # confirm file_name is expected to start with '/'.
    save_fn = save_dir + stem + '_prob.mat'
    sio.savemat(save_fn, {'array': mat_prob})

    tf = transforms.Compose([
        transforms.ToPILImage(),
        # Exact (height, width). An int size only sets the shorter edge.
        transforms.Resize((full_img.size[1], full_img.size[0])),
        transforms.ToTensor()
    ])
    probs = tf(probs.cpu())
    full_mask = probs.squeeze().cpu().numpy()
    return full_mask > out_threshold
def infer(args, unlabeled, ckpt_file):
    """Run a saved UNet checkpoint over the unlabeled subset of the train set.

    :param args: dict providing ``"TRAINIMAGEDATA_DIR"``,
        ``"TRAINLABEL_DIRECTORY"`` and ``"EXPT_DIR"``.
    :param unlabeled: indices into the training dataset to predict on.
    :param ckpt_file: checkpoint file name inside ``args["EXPT_DIR"]``.
    :return: ``{"outputs": {i: ndarray}}``, where ``i`` is the sample's
        position in ``unlabeled`` order.

    Reads ``img_scale``, ``batch_size`` and ``device``, which are not defined
    here. They are presumably module-level globals; confirm.
    """
    traindataset = BasicDataset(
        args["TRAINIMAGEDATA_DIR"], args["TRAINLABEL_DIRECTORY"], img_scale
    )
    unlabeled_dataset = Subset(traindataset, unlabeled)
    # No shuffling, so the running prediction index lines up with the order
    # of `unlabeled`.
    unlabeled_loader = DataLoader(
        unlabeled_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

    predix = 0
    predictions = {}

    # Load the saved model weights.
    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(
        torch.load(os.path.join(args["EXPT_DIR"], ckpt_file),
                   map_location=device))
    net.eval()

    with tqdm(total=len(unlabeled_loader), desc="Validation round",
              unit="batch", leave=False) as pbar:
        for batch in unlabeled_loader:
            imgs = batch["image"].to(device=device, dtype=torch.float32)
            with torch.no_grad():
                mask_pred = net(imgs)
            for logit in mask_pred:
                predictions[predix] = logit.cpu().numpy()
                predix += 1
            pbar.update()

    return {"outputs": predictions}
def train_net(net, device, epochs=5, batch_size=1, lr=0.001, val_percent=0.1,
              save_cp=True, img_scale=0.5):
    """Train ``net`` on the ``dir_img``/``dir_mask`` dataset.

    * A random ``val_percent`` share of the data is held out for validation.
    * Optimization uses RMSprop with ReduceLROnPlateau, driven by the
      validation score.
    * The loss is CrossEntropyLoss when ``net.n_classes > 1`` and
      BCEWithLogitsLoss otherwise.
    * Roughly ten times per epoch, weights and gradients are logged to
      TensorBoard and the validation score is computed.
    * When ``save_cp`` is true, a state-dict checkpoint is written to
      ``dir_checkpoint`` after every epoch.

    Relies on module-level ``dir_img``, ``dir_mask`` and ``dir_checkpoint``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8,
                              momentum=0.9)
    # Cross-entropy (multi-class) is minimized; Dice (single-class) is
    # maximized.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    # About ten evaluations per epoch. max(1, ...) prevents a modulo-by-zero
    # when the training set has fewer than 10 * batch_size samples.
    eval_interval = max(1, n_train // (10 * batch_size))

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % eval_interval == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # Parameters that did not contribute to the loss
                        # have no gradient.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)
                    val_score = eval_net(net, val_loader, device)
                    # Restore training mode regardless of what eval_net
                    # leaves behind.
                    net.train()
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
def train_net(net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1,
              save_cp=True, img_scale=0.5, data_augment=True):
    """Train ``net`` with Adam and BCEWithLogitsLoss on ``dir_img``/``dir_mask``.

    * A random ``val_percent`` share is held out for validation.
    * About ten times per epoch the Dice score is computed, and ``BEST.pth``
      is overwritten whenever it improves.
    * When ``save_cp`` is true, a per-epoch checkpoint is written to
      ``dir_checkpoint``.
    * When ``data_augment`` is true, each (image, mask) pair in a batch is
      transformed with ``my_segmentation_transforms``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=True)

    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()  # 1 class

    # max(1, ...) prevents a modulo-by-zero for datasets smaller than
    # 10 * batch_size samples.
    eval_interval = max(1, len(dataset) // (10 * batch_size))

    best_score = 0.
    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'
                assert true_masks.shape[1] == net.n_classes, \
                    f'Network has been defined with {net.n_classes} output classes, ' \
                    f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
                    'the masks are loaded correctly.'

                if data_augment:
                    # Augment each image together with its mask so that the
                    # pair stays aligned.
                    for i in range(len(imgs)):
                        imgs[i], true_masks[i] = my_segmentation_transforms(
                            imgs[i], true_masks[i])

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % eval_interval == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    # Validation runs mid-epoch; make sure the remaining
                    # batches of this epoch train in training mode.
                    net.train()
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    print(" ")
                    print('Validation Dice Coeff: {}'.format(val_score))
                    if best_score < val_score:
                        torch.save(net.state_dict(), 'BEST.pth')
                        logging.info('Best saved !')
                        best_score = val_score

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
def main():
    """Train ``unet_base`` on the lane-segmentation CSV for a single pass.

    Builds an augmented LaneDataset from ``train.csv``, an Adam optimizer and
    a BCE-with-logits plus multiclass Dice loss. It then iterates once over
    ``training_data_batch``.

    NOTE(review): problems visible in this function, to fix once the loss
    contract is confirmed:
      * ``torch.argmax`` on the model output has no gradient path, so the
        loss does not depend on the model parameters. Gradients only reach
        ``mask``, which is made a leaf via ``requires_grad_()``. As written,
        ``optimizer.step()`` receives no gradients for the model.
      * ``epochs`` is unused; only one pass over the data is made.
      * ``dataset``, ``train_loader`` and ``val_loader`` are built but never
        used. ``data_dir`` is empty.
      * ``model.cuda()`` is unconditional, while batches are only moved to
        the GPU when CUDA is available.
      * The CSV path starts with ``~``, which Python does not expand unless
        LaneDataset does so. Confirm.
    """
    # network = 'deeplabv3p'
    # save_model_path = "./model_weights/" + network + "_"
    # model_path = "./model_weights/" + network + "_0_6000"
    data_dir = ''
    val_percent = .1
    epochs = 9
    # DataLoader worker options are only used when CUDA is available.
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    training_dataset = LaneDataset(
        "~/workspace/myDL/CV/week8/Lane_Segmentation_pytorch/data_list/train.csv",
        transform=transforms.Compose(
            [ImageAug(), DeformAug(), ScaleAug(), CutOut(32, 0.5), ToTensor()]))

    training_data_batch = DataLoader(training_dataset, batch_size=2,
                                     shuffle=True, drop_last=True, **kwargs)

    dataset = BasicDataset(data_dir, img_size=cfg.IMG_SIZE,
                           crop_offset=cfg.crop_offset)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)

    model = unet_base(cfg.num_classes, cfg.IMG_SIZE)
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.base_lr,
                                 betas=(0.9, 0.99))
    bce_criterion = nn.BCEWithLogitsLoss()
    dice_criterion = MulticlassDiceLoss()

    model.train()
    epoch_loss = 0
    dataprocess = tqdm(training_data_batch)
    for batch_item in dataprocess:
        image, mask = batch_item['image'], batch_item['mask']
        if torch.cuda.is_available():
            image, mask = image.cuda(), mask.cuda()
        image = image.to(torch.float32).requires_grad_()
        mask = mask.to(torch.float32).requires_grad_()
        masks_pred = model(image)
        # argmax yields integer class indices; this cuts the gradient path
        # to the model (see NOTE in the docstring).
        masks_pred = torch.argmax(masks_pred, dim=1)
        masks_pred = masks_pred.to(torch.float32)
        mask = mask.to(torch.float32)
        # print('mask_pred:', masks_pred)
        # print('mask:', mask)
        loss = bce_criterion(masks_pred, mask) + dice_criterion(
            masks_pred, mask)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def train_net(net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1,
              save_cp=True, img_scale=0.5):
    """Train ``net`` with RMSprop, validating once per epoch.

    After every epoch the validation score is logged and
    ``<log_dir>/latest.pth`` is written. ``<log_dir>/best.pth`` is written
    whenever the score improves. "Improves" means lower for multi-class nets
    (cross-entropy) and higher for single-class nets (Dice). ``save_cp`` is
    only logged and does not control saving.

    Relies on module-level ``dir_img``, ``dir_mask`` and ``log_dir``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    # Multi-class runs are scored with cross-entropy (lower is better) and
    # single-class runs with Dice (higher is better). A plain `>` would keep
    # the *worst* cross-entropy checkpoint as "best".
    higher_is_better = net.n_classes == 1
    best_score = 0 if higher_is_better else float('inf')
    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        start = time.time()
        with tqdm(total=n_train, desc=f'Epoch {epoch}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                pbar.set_postfix(**{'loss': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])

        cost_time = time.time() - start
        logging.info(f"{epoch} loss: {epoch_loss:.5f} time {cost_time:.3f}s")

        val_score = eval_net(net, val_loader, device, n_val)
        if net.n_classes > 1:
            logging.info('Validation cross entropy: {:.5f}'.format(val_score))
        else:
            logging.info('Validation Dice Coeff: {:.5f}'.format(val_score))

        if higher_is_better:
            improved = val_score > best_score
        else:
            improved = val_score < best_score
        if improved:
            torch.save(net.state_dict(), log_dir + '/best.pth')
            best_score = val_score
            logging.info(f'best improved to {val_score:.5f}')
        torch.save(net.state_dict(), log_dir + "/latest.pth")
def train_net(net, device, epochs=100, batch_size=1, lr=0.1, val_percent=0.2,
              save_cp=True, img_scale=1, accumulation_steps=4):
    """Train ``net`` with Adam, Dice loss and gradient accumulation.

    * Gradients from ``accumulation_steps`` consecutive batches are summed
      before each optimizer step. The loss is not rescaled. The default of 4
      matches the previously hard-coded value.
    * Validation, the LR scheduler and TensorBoard logging run once per
      epoch.
    * When ``save_cp`` is true, the whole model is saved to
      ``dir_checkpoint`` whenever the summed epoch loss or the validation
      score beats the best seen so far.

    Relies on module-level ``dir_img``, ``dir_mask``, ``dir_checkpoint``,
    ``gene_eval_data`` and ``dice_loss``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    gene_eval_data(val_loader, dir='./data/val/')

    writer = SummaryWriter(
        comment='LR_{}_BS_{}_SCALE_{}'.format(lr, batch_size, img_scale))
    global_step = 0

    logging.info('''Starting training:
        Epochs:          {}
        Batch size:      {}
        Learning rate:   {}
        Training size:   {}
        Validation size: {}
        Checkpoints:     {}
        Device:          {}
        Images scaling:  {}
    '''.format(epochs, batch_size, lr, n_train, n_val, save_cp, device.type,
               img_scale))

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', factor=0.5,
        patience=20)
    criterion = dice_loss

    last_loss = 9999
    last_val_score = 0
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        step = 0
        with tqdm(total=n_train, desc='Epoch {}/{}'.format(epoch + 1, epochs),
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    'Network has been defined with {} input channels, '.format(net.n_channels) + \
                    'but loaded images have {} channels. Please check that '.format(imgs.shape[1]) + \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)

                global_step += 1
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                step += 1
                if step % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step = 0
                pbar.update(imgs.shape[0])

        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(),
                                 global_step)
            # zero_grad() may leave .grad as None (set_to_none), e.g. when the
            # epoch ended exactly on an accumulation boundary.
            if value.grad is not None:
                writer.add_histogram('grads/' + tag,
                                     value.grad.data.cpu().numpy(),
                                     global_step)

        # Apply the gradients of a trailing partial accumulation window, so
        # they neither leak into the next epoch nor are silently dropped.
        if step:
            optimizer.step()
            optimizer.zero_grad()

        val_score = eval_net(net, val_loader, device)
        scheduler.step(val_score)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                          global_step)

        if net.n_classes > 1:
            logging.info('Validation cross entropy: {}'.format(val_score))
            writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Train Loss: {} Validation Dice Coeff: {} '.format(
                epoch_loss / n_train, val_score))
            writer.add_scalar('Dice/test', val_score, global_step)

        writer.add_images('images', imgs, global_step)
        if net.n_classes == 1:
            writer.add_images('masks/true', true_masks, global_step)
            writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.3,
                              global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if last_loss > epoch_loss or last_val_score < val_score:
                last_loss = min(last_loss, epoch_loss)
                last_val_score = max(last_val_score, val_score)
                torch.save(
                    net,
                    dir_checkpoint + 'CP_epoch{}Trainloss{}ValDice{}.pt'.format(
                        epoch + 1, epoch_loss / n_train, val_score))
                logging.info('Checkpoint {} saved !'.format(epoch + 1) +
                             ' CP_epoch{}Trainloss{}ValDice{}.pt'.format(
                                 epoch + 1, epoch_loss / n_train, val_score))

    writer.close()