def get_model(model_path, model_type):
    """Build a binary segmentation network and load trained weights into it.

    :param model_path: path to a checkpoint whose 'model' entry holds the state
        dict; every 'module.' substring (left by DataParallel) is removed from
        the keys before loading.
    :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34'; any other value
        falls back to 'UNet'.
    :return: the model in eval mode, moved to CUDA when available.
    """
    num_classes = 1

    if model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)
    elif model_type == 'UNet16':
        model = UNet16(num_classes=num_classes)
    elif model_type == 'AlbuNet34':
        model = AlbuNet34(num_classes=num_classes)
    else:
        # 'UNet' and any unrecognised name both use the plain UNet.
        model = UNet(num_classes=num_classes)

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the tensors into the model's own parameters anyway.
    state = torch.load(str(model_path), map_location='cpu')
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)

    # Switch to inference mode on every path (dropout/batch-norm behaviour),
    # including the CUDA one.
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model
def test(opt, log_dir, generator=None):
    """Run the generator over the test dataset and save its outputs as PNG files.

    :param opt: options object; reads sample_num, channels, batch_size, alpha
        and load_model (checkpoint path with a 'g_state_dict' entry).
    :param log_dir: directory under which a 'test/' folder receives the images.
    :param generator: optional pre-built generator; when None a UNet is built
        from ``opt`` and its weights are loaded from ``opt.load_model``.
        The generator is moved to the device and switched to eval mode.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if generator is None:
        generator = UNet(opt.sample_num, opt.channels, opt.batch_size, opt.alpha)
        checkpoint = torch.load(opt.load_model, map_location=device)
        generator.load_state_dict(checkpoint['g_state_dict'])
        # Free the checkpoint copy of the weights before inference starts.
        del checkpoint
        torch.cuda.empty_cache()
    generator.to(device)
    generator.eval()

    # Make sure the destination exists before any image is written.
    os.makedirs(os.path.join(log_dir, "test"), exist_ok=True)

    dataloader = torch.utils.data.DataLoader(
        MyDataset_test(opt), opt.batch_size, shuffle=True, num_workers=0)
    for imgs, filename in dataloader:
        with torch.no_grad():
            test_img = generator(imgs.to(device))
        # Only the first file name in the batch names the saved image grid.
        filename = filename[0].split('/')[-1]
        filename = "test/" + filename + '.png'
        convert_im(test_img, os.path.join(log_dir, filename),
                   nrow=5, normalize=True, save_im=True)
def test_moving(opt, log_dir, generator=None):
    """Run the generator over every frame of each test sequence; save one PNG per frame.

    :param opt: options object; reads sample_num, channels, batch_size, alpha,
        load_model and num_workers_dataloader.
    :param log_dir: results are written to
        ``log_dir/test/<file name>/<file name>_<k>.png``.
    :param generator: optional pre-built generator; when None a UNet is built
        from ``opt`` and its 'g_state_dict' is loaded from ``opt.load_model``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if generator is None:
        generator = UNet(opt.sample_num, opt.channels, opt.batch_size, opt.alpha)
        checkpoint = torch.load(opt.load_model, map_location=device)
        generator.load_state_dict(checkpoint['g_state_dict'])
        del checkpoint
        torch.cuda.empty_cache()
    generator.to(device)
    generator.eval()

    dataloader = torch.utils.data.DataLoader(
        MyDataset_test_moving(opt), opt.batch_size, shuffle=True,
        num_workers=opt.num_workers_dataloader)
    for imgs, filename in dataloader:
        # Only the first file name in the batch names the output folder.
        filename = filename[0].split('/')[-1]
        folder_path = os.path.join(log_dir, "test/%s" % filename)
        # Created once per sequence rather than once per frame.
        os.makedirs(folder_path, exist_ok=True)
        with torch.no_grad():
            # imgs holds one entry per frame of the sequence.
            for k, frame in enumerate(imgs):
                test_img = generator(frame.to(device))
                filename_ = filename + '_' + str(k) + '.png'
                convert_im(test_img, os.path.join(folder_path, filename_),
                           nrow=5, normalize=True, save_im=True)
def get_model(model_path, model_type, num_classes):
    """Build a segmentation network and load trained weights into it.

    :param model_path: path to a checkpoint whose 'model' entry holds the state
        dict; every 'module.' substring (left by DataParallel) is removed from
        the keys before loading.
    :param model_type: 'UNet' or a key of the module-level ``model_list``
        (presumably 'UNet16', 'UNet11', 'LinkNet34', ...).
    :param num_classes: number of output channels of the network.
    :return: the model in eval mode, moved to CUDA when available.
    :raises KeyError: if model_type is neither 'UNet' nor a key of ``model_list``.
    """
    if model_type == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = model_list[model_type]
        model = model_name(num_classes=num_classes)

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    state = torch.load(str(model_path), map_location='cpu')
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)

    # eval() must run on the CUDA path too, not only on the CPU path.
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model
def export2caffe(weights, num_classes, img_size, name='DeepLabV3Plus'):
    """Convert a trained UNet checkpoint into Caffe '.prototxt' / '.caffemodel' files.

    :param weights: checkpoint path whose 'model' entry holds the UNet state dict.
    :param num_classes: number of output classes of the UNet.
    :param img_size: size of the dummy input used for conversion;
        ``img_size[1]`` is used as the height and ``img_size[0]`` as the width.
    :param name: base name of the network and of the two output files. Defaults
        to 'DeepLabV3Plus' for backward compatibility, although the exported
        network is a UNet; pass e.g. 'UNet' for accurately named files.
    """
    model = UNet(num_classes)
    checkpoint = torch.load(weights, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()
    fuse(model)

    # NCHW dummy input: batch 1, 3 channels, height, width.
    dummy_input = torch.ones([1, 3, img_size[1], img_size[0]])
    pytorch2caffe.trans_net(model, dummy_input, name)
    pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))
def test(weights_path):
    """Predict egg and pan masks for every training image and save them for inspection.

    Every '.jpg'/'.JPG'/'.png' file in 'dataset/train/images/' is run through
    UNet. For each one, the egg mask, the pan mask and a copy of the input are
    written to 'test_vis/' as '<stem>_egg<ext>', '<stem>_pan<ext>' and '<name>'.

    :param weights_path: path to a UNet state dict.
    """
    image_dir = 'dataset/train/images/'
    out_dir = 'test_vis/'
    image_names = [
        name for name in os.listdir(image_dir)
        if name.endswith(('.jpg', '.JPG', '.png'))
    ]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    model.eval()
    model.load_state_dict(torch.load(weights_path, map_location=device))

    img_size = 512
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)

    for image_name in tqdm(image_names):
        img = cv2.imread(os.path.join(image_dir, image_name))
        if img is None:
            # cv2.imread signals an unreadable file with None; skip it.
            continue
        img_torch = prepare_image(img, img_size)

        with torch.no_grad():
            pred_egg_mask, pred_pan_mask = model(img_torch)
            # Binarise at probability 0.5.
            pred_egg_mask = (torch.sigmoid(pred_egg_mask) >= 0.5).type(pred_egg_mask.dtype)
            pred_pan_mask = (torch.sigmoid(pred_pan_mask) >= 0.5).type(pred_pan_mask.dtype)

        # First sample, first channel; scale {0, 1} to {0, 255}. 255 is the
        # largest 8-bit value, whereas 256 overflows when cast to uint8.
        pred_egg_mask = pred_egg_mask.cpu().numpy()[0][0] * 255
        pred_pan_mask = pred_pan_mask.cpu().numpy()[0][0] * 255
        # postprocess_masks presumably resizes the masks back to the image shape.
        pred_egg_mask, pred_pan_mask = postprocess_masks(img, pred_egg_mask, pred_pan_mask)

        stem, ext = os.path.splitext(image_name)
        cv2.imwrite(os.path.join(out_dir, stem + '_egg' + ext), pred_egg_mask)
        cv2.imwrite(os.path.join(out_dir, stem + '_pan' + ext), pred_pan_mask)
        cv2.imwrite(os.path.join(out_dir, image_name), img)
def test(device, model_path, dataset_path, out_path):
    """Tests the network on dataset_path: colourise each grayscale image and save it.

    :param device: torch device to run on.
    :param model_path: path to the UNet weights; if the file is missing, the
        network keeps its initial weights and a warning is issued.
    :param dataset_path: directory read by GrayDataset (with ``val_transform``).
    :param out_path: output directory; results are saved as 000000.png, 000001.png, ...
    """
    import warnings

    network = UNet(1, 3).to(device)
    if os.path.exists(model_path):
        # map_location lets GPU-trained weights load on the chosen device.
        network.load_state_dict(torch.load(model_path, map_location=device))
    else:
        warnings.warn("model weights not found at {}; testing an untrained network".format(model_path))

    dataset = GrayDataset(dataset_path, transform=val_transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                         num_workers=cpu_count())
    os.makedirs(out_path, exist_ok=True)

    with torch.no_grad():
        network.eval()
        for i, gray in enumerate(tqdm.tqdm(loader, desc="Testing", leave=False)):
            gray = gray.to(device)
            pred_color = network(gray)
            # Map x -> x * 0.5 + 0.5; assumes the output lives in the
            # (0.5, 0.5)-normalised [-1, 1] range — confirm against val_transform.
            result = F.to_pil_image((pred_color.cpu().squeeze() * 0.5) + 0.5)
            result.save(os.path.join(out_path, "{:06d}.png".format(i)))
def get_model(model_path, model_type='unet11', problem_type='binary'):
    """Build a segmentation network and load trained weights into it.

    :param model_path: path to a checkpoint whose 'model' entry holds the state
        dict; every 'module.' substring (left by DataParallel) is removed from
        the keys before loading.
    :param model_type: 'UNet', 'UNet16', 'UNet11', 'LinkNet34' or 'DLinkNet',
        matched case-insensitively (so the default 'unet11' selects UNet11).
    :param problem_type: 'binary' (1 class), 'parts' (4) or 'instruments' (8).
    :return: the model in eval mode, moved to CUDA when available.
    :raises ValueError: if model_type or problem_type is not recognised.
    """
    classes_by_problem = {'binary': 1, 'parts': 4, 'instruments': 8}
    if problem_type not in classes_by_problem:
        raise ValueError('unknown problem_type {!r}; expected one of {}'.format(
            problem_type, sorted(classes_by_problem)))
    num_classes = classes_by_problem[problem_type]

    # Keyed by lower-cased name; a builder is only called once it is selected.
    builders = {
        'unet16': lambda: UNet16(num_classes=num_classes),
        'unet11': lambda: UNet11(num_classes=num_classes),
        'linknet34': lambda: LinkNet34(num_classes=num_classes),
        'unet': lambda: UNet(num_classes=num_classes),
        'dlinknet': lambda: D_LinkNet34(num_classes=num_classes, pretrained=True),
    }
    builder = builders.get(str(model_type).lower())
    if builder is None:
        raise ValueError('unknown model_type {!r}; expected one of {}'.format(
            model_type, ['UNet', 'UNet16', 'UNet11', 'LinkNet34', 'DLinkNet']))
    model = builder()

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    state = torch.load(str(model_path), map_location='cpu')
    state = {key.replace('module.', ''): value for key, value in state['model'].items()}
    model.load_state_dict(state)

    # eval() must run on the CUDA path too, not only on the CPU path.
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model
def test(device, gen_model, fake_dataset_path, out_dir):
    """Tests a GAN generator: colourise grayscale images and save them as PNGs.

    :param device: torch device to run on.
    :param gen_model: path to the generator's state dict; if the file is
        missing, the generator keeps its initial weights and a warning is issued.
    :param fake_dataset_path: directory read by GrayDataset.
    :param out_dir: output directory; images are saved as 000000.png, 000001.png, ...
    """
    import warnings

    print("Test a gan")
    val_transform = tv.transforms.Compose([
        tv.transforms.Resize((224, 224)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, ), (0.5, )),
    ])
    fakedataset = GrayDataset(fake_dataset_path, transform=val_transform)
    fakeloader = torch.utils.data.DataLoader(fakedataset, batch_size=1, shuffle=False,
                                             num_workers=cpu_count())

    generator = UNet(1, 3).to(device)
    if os.path.exists(gen_model):
        # map_location lets GPU-trained weights load on the chosen device.
        generator.load_state_dict(torch.load(gen_model, map_location=device))
    else:
        warnings.warn("generator weights not found at {}; testing an untrained generator".format(gen_model))

    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        generator.eval()
        for i, fake_data in enumerate(tqdm.tqdm(fakeloader)):
            fake_data = fake_data.to(device)
            fake = generator(fake_data)
            # Inverse of Normalize((0.5,), (0.5,)), assuming the generator emits
            # values in the normalised [-1, 1] range.
            fake_im = fake.squeeze().cpu() * 0.5 + 0.5
            fake_im = tv.transforms.functional.to_pil_image(fake_im)
            fake_im.save(os.path.join(out_dir, "{:06d}.png".format(i)))
def inference(img_dir='data/samples',
              img_size=256,
              output_dir='outputs',
              weights='weights/best_miou.pt',
              unet=False):
    """Segment every image in img_dir and save a colour-coded mask for each one.

    Uses the module-level ``device``. Each image is resized so that its longer
    side is about ``img_size``, with both sides rounded down to a multiple of 32
    (at least 32). The per-pixel argmax over the 30 output channels, taken at
    the original resolution, is painted with ``VOC_COLORMAP``.

    :param img_dir: folder of .jpg/.jpeg/.png/.tiff images.
    :param img_size: target length of the longer side before rounding.
    :param output_dir: folder the masks are written to, under the input names.
    :param weights: checkpoint whose 'model' entry holds the state dict.
    :param unet: use UNet instead of DeepLabV3Plus.
    """
    os.makedirs(output_dir, exist_ok=True)
    model = UNet(30) if unet else DeepLabV3Plus(30)
    model = model.to(device)
    state_dict = torch.load(weights, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.eval()

    names = [
        n for n in os.listdir(img_dir)
        if os.path.splitext(n)[1] in ['.jpg', '.jpeg', '.png', '.tiff']
    ]
    with torch.no_grad():
        for name in tqdm(names):
            path = os.path.join(img_dir, name)
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals an unreadable file with None; skip it.
                continue
            img_shape = img.shape
            longest = max(img_shape[:2])
            # Round each side down to a multiple of 32, but never to 0, which
            # would make cv2.resize fail for very elongated images.
            h = max(1, int((img_shape[0] / longest * img_size) // 32))
            w = max(1, int((img_shape[1] / longest * img_size) // 32))
            img = cv2.resize(img, (w * 32, h * 32))
            # BGR -> RGB, HWC -> CHW. ascontiguousarray removes the negative
            # stride of the channel flip, which torch.from_numpy rejects.
            img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
            # Build on the CPU, then move: the legacy FloatTensor(..., device=)
            # constructor only accepts CPU devices.
            img = torch.from_numpy(img).float().unsqueeze(0).to(device) / 255.
            output = model(img)[0].cpu().numpy().transpose(1, 2, 0)
            output = cv2.resize(output, (img_shape[1], img_shape[0]))
            output = output.argmax(2)
            seg = np.zeros(img_shape, dtype=np.uint8)
            for ci, color in enumerate(VOC_COLORMAP):
                seg[output == ci] = color
            cv2.imwrite(os.path.join(output_dir, name), seg)
# Command-line inference on a single image: predict a binary mask with UNet and
# save it as an 8-bit image under --output-dir.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--input', required=True)
parser.add_argument('--output', default='predicted.jpg')
parser.add_argument('--output-dir', default='logs')
opt = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = UNet(3, 1)
if device == torch.device('cuda'):
    # On GPU the weights are loaded into a DataParallel wrapper, so the
    # checkpoint keys must carry its 'module.' prefix.
    # NOTE(review): on CPU un-prefixed keys are expected — confirm the format.
    net = nn.DataParallel(net)
net.to(device=device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(os.path.expanduser(opt.model), map_location=device))
net.eval()

# image
image = Image.open(os.path.expanduser(opt.input))
image = CarvanaDatasetTransforms([256, 512]).transform(image)
# Tensor.to() is not in-place: keep the returned tensor.
image = image.to(device)

with torch.no_grad():
    mask = net(image.unsqueeze(0))[0]
    mask = torch.sigmoid(mask)
mask = mask.squeeze(0).cpu().numpy()
# Threshold at 0.5 and scale {False, True} to {0, 255} for saving.
mask = mask > 0.5
mask = Image.fromarray((mask * 255).astype(np.uint8))
os.makedirs(opt.output_dir, exist_ok=True)
mask.save(os.path.join(opt.output_dir, opt.output))
# Read the inference settings from the configuration mapping.
test_file, input_size, num_channels, n_classes, bilinear = (
    config_params[key]
    for key in ('test_images_txt', 'input_size', 'num_channels', 'n_classes', 'bilinear')
)
# Seeding stays disabled, so the random sampler draws a different image each run:
# torch.manual_seed(config_params['seed'])

test_set = NucleusTestDataset(test_file, input_size)
test_loader = DataLoader(test_set, batch_size=1, sampler=RandomSampler(test_set))

# Run on the GPU whenever one is present.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = UNet(n_channels=num_channels, n_classes=n_classes, bilinear=bilinear)
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Threshold handed to predict_mask (presumably the mask binarisation cut-off).
threshold = float(args.threshold)

# Take one randomly sampled test image, predict its mask and display both.
img, idx = next(iter(test_loader))
pred = predict_mask(model, img, threshold, device)
visualize(img, pred)
def predict():
    """Run the trained UNet over the test split and save each binary mask as .npy.

    Reads the module-level ``opt``: load_model_path, use_gpu, test_batch_size,
    num_workers and out_threshold. Mask ``ii`` is saved as
    '/home/bobo/data/test/test8/<10000 + ii>_mask.npy'.
    """
    # One-channel grayscale input, one output channel (binary mask).
    net = UNet(n_channels=1, n_classes=1)
    net.eval()

    # Load a checkpoint that may come from a multi-GPU (DataParallel) model.
    if opt.load_model_path:
        checkpoint = t.load(opt.load_model_path, map_location='cpu')
        # DataParallel prefixes keys with 'module.'; strip it only where present
        # so checkpoints saved from an un-wrapped model load as well.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['net'].items()
        }
        net.load_state_dict(state_dict)
        print('加载预训练模型{}'.format(opt.load_model_path))

    if opt.use_gpu:
        net.cuda()

    test_data = NodeDataSet(test=True)
    test_dataloader = DataLoader(test_data, opt.test_batch_size, shuffle=False,
                                 num_workers=opt.num_workers)

    for ii, full_img in enumerate(test_dataloader):
        # First [0]: the image entry of the loaded batch; second [0]: its first
        # sample (a batch size of 1 is assumed — confirm test_batch_size).
        img_test = full_img[0][0].unsqueeze(0)
        if opt.use_gpu:
            img_test = img_test.cuda()

        with t.no_grad():
            output = net(img_test)
            probs = t.sigmoid(output).squeeze(0)
        full_mask = probs.squeeze().cpu().numpy()

        # Binarise; the author observed very small raw predictions (max ~0.01),
        # so opt.out_threshold matters here.
        mask = full_mask > opt.out_threshold
        np.save('/home/bobo/data/test/test8/' + str(10000 + ii) + '_mask.npy', mask)

    print('测试完毕')
def train():
    """Train the UNet on NodeDataSet, logging progress to a Visualizer.

    All settings come from the module-level ``opt``. Each epoch:
      * trains on the training split (BCE, optionally plus Dice loss);
      * saves {'net', 'optimizer', 'epoch'} to
        ``opt.checkpoint_root + '<epoch>_unet.pth'``, where 'epoch' is the total
        number of epochs completed so resuming continues where training stopped;
      * plots the mean validation ``dice_loss``;
      * every 10 epochs plots the 2-D recall on the test split.
    Setting ``opt.load_model_path`` resumes from such a checkpoint.
    """
    if opt.use_gpu:
        # NOTE(review): GPU 1 is hard-coded; it presumably should agree with
        # opt.device_ids — confirm.
        t.cuda.set_device(1)

    # n_channels=1: single-channel grayscale (medical) images; n_classes=1: binary mask.
    net = UNet(n_channels=1, n_classes=1)
    if opt.use_gpu:
        # Move the parameters before restoring optimizer state, so restored
        # momentum buffers end up on the same device as the parameters.
        net.cuda()
    optimizer = t.optim.SGD(net.parameters(), lr=opt.learning_rate,
                            momentum=0.9, weight_decay=0.0005)
    # Binary cross-entropy on sigmoid probabilities (the author notes it suits
    # masks that cover a large part of the image).
    criterion = t.nn.BCELoss()

    start_epoch = 0
    if opt.load_model_path:
        checkpoint = t.load(opt.load_model_path, map_location='cpu')
        # Multi-GPU checkpoints prefix keys with 'module.'; strip it only where present.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['net'].items()
        }
        net.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    # The learning rate is multiplied by 0.1 at each epoch in opt.milestones.
    if start_epoch == 0:
        scheduler = t.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=opt.milestones, gamma=0.1, last_epoch=-1)
        print('从头训练 ,学习率为{}'.format(optimizer.param_groups[0]['lr']))
    else:
        scheduler = t.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=opt.milestones, gamma=0.1, last_epoch=start_epoch)
        print('加载预训练模型{}并从{}轮开始训练,学习率为{}'.format(
            opt.load_model_path, start_epoch, optimizer.param_groups[0]['lr']))

    if opt.use_gpu:
        # Data-parallel over opt.device_ids.
        net = t.nn.DataParallel(net, device_ids=opt.device_ids)
        net.cuda()
        cudnn.benchmark = True

    vis = Visualizer(opt.env)

    train_data = NodeDataSet(train=True)
    val_data = NodeDataSet(val=True)
    test_data = NodeDataSet(test=True)
    train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=True,
                                num_workers=opt.num_workers)
    test_dataloader = DataLoader(test_data, opt.test_batch_size, shuffle=False,
                                 num_workers=opt.num_workers)

    for epoch in range(opt.max_epoch - start_epoch):
        print('开始 epoch {}/{}.'.format(start_epoch + epoch + 1, opt.max_epoch))
        epoch_loss = 0
        num_batches = 0

        for ii, (img, mask) in enumerate(train_dataloader):
            # Assigned on every path; it was previously only set on the GPU path.
            true_masks = mask
            if opt.use_gpu:
                img = img.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(img)
            masks_probs = t.sigmoid(masks_pred)
            # Loss = binary cross-entropy (+ Dice loss when enabled).
            loss = criterion(masks_probs.view(-1), true_masks.view(-1))
            if opt.use_dice_loss:
                loss += dice_loss(masks_probs, true_masks)
            epoch_loss += loss.item()
            num_batches += 1
            if ii % 2 == 0:
                vis.plot('训练集loss', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        vis.log("epoch:{epoch},lr:{lr},loss:{loss}".format(
            epoch=epoch, loss=loss.item(), lr=optimizer.param_groups[0]['lr']))
        # Mean over the number of batches (not the last batch index).
        vis.plot('每轮epoch的loss均值', epoch_loss / max(num_batches, 1))
        # Step the schedule after this epoch's optimizer updates, the order
        # required since PyTorch 1.1.
        scheduler.step()

        # Checkpoint; 'epoch' counts completed epochs across resumed runs.
        completed_epochs = start_epoch + epoch + 1
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': completed_epochs,
        }
        t.save(state, opt.checkpoint_root + '{}_unet.pth'.format(start_epoch + epoch))

        # ============ validation: mean dice_loss of the thresholded prediction ============
        net.eval()
        tot = 0
        num_val = 0
        with t.no_grad():
            for img_val, mask_val in val_dataloader:
                true_mask_val = mask_val
                if opt.use_gpu:
                    img_val = img_val.cuda()
                    true_mask_val = true_mask_val.cuda()
                mask_pred = net(img_val)
                mask_pred = (t.sigmoid(mask_pred) > 0.5).float()  # threshold 0.5
                tot += dice_loss(mask_pred, true_mask_val).item()
                num_val += 1
        val_dice = tot / max(num_val, 1)
        vis.plot('验证集 Dice损失', val_dice)

        # ============ test-set recall, every 10 epochs ============
        if epoch % 10 == 0:
            result_test = []
            with t.no_grad():
                # Only segmentation quality is measured, so ground truth is unused here.
                for img_test, _ in test_dataloader:
                    if opt.use_gpu:
                        img_test = img_test.cuda()
                    mask_pred_test = net(img_test)  # e.g. [1, 1, 512, 512]
                    probs = t.sigmoid(mask_pred_test).squeeze().cpu().numpy()
                    result_test.append(probs > opt.out_threshold)
            vis.plot('测试集二维召回率', getRecall(result_test).getResult())

        net.train()
class Runner(object):
    """Owns the boundary-detection UNet and runs its train / valid / test passes.

    Holds the model, loss, optimizer and scheduler, the device placement, and a
    SummaryWriter whose ``logdir`` receives the model summary, hyper-parameters,
    checkpoints ('<epoch>.pt') and saved predictions.
    """

    def __init__(self, hparams, train_size: int, class_weight: Optional[Tensor] = None):
        # model, criterion, and prediction
        self.model = UNet(ch_in=2, ch_out=1, **hparams.model)
        self.sigmoid = torch.nn.Sigmoid()
        # reduction='none' keeps per-frame losses so calc_loss can weight them.
        self.criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.class_weight = class_weight

        # for prediction: seconds per frame, and 6 s / 12 s in frames (minus one)
        self.frame2time = hparams.hop_size / hparams.sample_rate
        self.T_6s = round(6 / self.frame2time) - 1
        self.T_12s = round(12 / self.frame2time) - 1
        self.metrics = ('precision', 'recall', 'F1')

        # optimizer and scheduler
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
        )
        self.scheduler = CosineLRWithRestarts(self.optimizer,
                                              batch_size=hparams.batch_size,
                                              epoch_size=train_size,
                                              **hparams.scheduler)
        self.scheduler.step()
        self.f1_last_restart = -1

        # device
        device_for_summary = self._init_device(hparams.device, hparams.out_device)

        # summary
        self.writer = SummaryWriter(logdir=hparams.logdir)
        path_summary = Path(self.writer.logdir, 'summary.txt')
        if not path_summary.exists():
            print_to_file(path_summary, summary,
                          (self.model, (2, 128, 16 * hparams.model['stride'][1]**4)),
                          dict(device=device_for_summary))

        # save hyperparameters
        path_hparam = Path(self.writer.logdir, 'hparams.txt')
        if not path_hparam.exists():
            with path_hparam.open('w') as f:
                for var in vars(hparams):
                    value = getattr(hparams, var)
                    print(f'{var}: {value}', file=f)

    def _init_device(self, device, out_device) -> str:
        """Place the model, criterion and sigmoid on the requested device(s).

        :param device: 'cpu', a GPU index, a device string such as 'cuda:1', or a
            sequence of indices / device strings (more than one -> DataParallel).
        :param out_device: GPU index or device string for outputs and the loss.
        :return: 'cpu' or 'cuda', used for the model summary.
        """
        if device == 'cpu':
            self.device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return 'cpu'

        # Normalise ``device`` to a list of GPU indices. The index is everything
        # after the last ':' so that e.g. 'cuda:10' is parsed as 10, not 0.
        if isinstance(device, int):
            device = [device]
        elif isinstance(device, str):
            device = [int(device.split(':')[-1])]
        elif not isinstance(device[0], int):
            device = [int(d.split(':')[-1]) for d in device]

        # out_device type
        if isinstance(out_device, int):
            out_device = torch.device(f'cuda:{out_device}')
        else:
            out_device = torch.device(out_device)

        self.device = torch.device(f'cuda:{device[0]}')
        self.out_device = out_device

        if len(device) > 1:
            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
        else:
            self.str_device = str(self.device)

        self.model.cuda(device[0])
        self.criterion.cuda(out_device)
        if self.sigmoid:
            self.sigmoid.cuda(device[0])
        torch.cuda.set_device(device[0])
        return 'cuda'

    def calc_loss(self, y: Tensor, out: Tensor, Ts: Union[List[int], int]) -> Tensor:
        """Class-weighted BCE-with-logits loss.

        :param y: (B, T) or (T,) targets
        :param out: (B, T) or (T,) logits
        :param Ts: length-B list of valid lengths, or an int for unbatched input
        :return: (1,) tensor; for each item, the weighted loss summed over its
            first T frames and divided by T, then summed over items
        """
        assert self.class_weight is not None
        if y.dim() == 1:  # if batch_size == 1
            # Add a batch axis so the 2-D slicing below applies unchanged.
            y = y.unsqueeze(0)
            out = out.unsqueeze(0)
            Ts = (Ts, )

        # class_weight[1] for frames with y > 0, class_weight[0] for y == 0.
        weight = (y > 0).float() * self.class_weight[1].item()
        weight += (y == 0).float() * self.class_weight[0].item()

        loss = torch.zeros(1, device=self.out_device)
        for ii, T in enumerate(Ts):
            loss_no_red = self.criterion(out[ii:ii + 1, ..., :T], y[ii:ii + 1, :T])
            loss += (loss_no_red * weight[ii:ii + 1, :T]).sum() / T
        return loss

    def predict(self, out_np: ndarray, Ts: Union[List[int], int]) \
            -> Tuple[List[ndarray], List]:
        """Peak-picking boundary prediction.

        A frame is a candidate when it is the maximum within +-T_6s frames; the
        candidates above their mean activation become boundaries.

        :param out_np: (B, T) or (T,)
        :param Ts: length B list or int
        :return: boundaries, thresholds
            boundaries: length B list of boundary interval ndarrays (seconds)
            thresholds: length B list of threshold values
        """
        if out_np.ndim == 1:  # if batch_size == 1
            out_np = (out_np, )
            Ts = (Ts, )

        boundaries = []
        thresholds = []
        for item, T in zip(out_np, Ts):
            candid_idx = []
            for idx in range(1, T - 1):
                i_first = max(idx - self.T_6s, 0)
                i_last = min(idx + self.T_6s + 1, T)
                if item[idx] >= np.amax(item[i_first:i_last]):
                    candid_idx.append(idx)

            threshold = np.mean(item[candid_idx])
            boundary_idx = [idx for idx in candid_idx if item[idx] > threshold]

            boundary_interval = np.array([[0] + boundary_idx, boundary_idx + [T]],
                                         dtype=np.float64).T
            boundary_interval *= self.frame2time

            boundaries.append(boundary_interval)
            thresholds.append(threshold)

        return boundaries, thresholds

    @staticmethod
    def evaluate(reference: Union[List[ndarray], ndarray],
                 prediction: Union[List[ndarray], ndarray]):
        """Sum mir_eval boundary-detection scores over items.

        :param reference: length B list of ndarray or just ndarray
        :param prediction: length B list of ndarray or just ndarray
        :return: (3,) ndarray
        """
        if isinstance(reference, ndarray):  # if batch_size == 1
            reference = (reference, )

        result = np.zeros(3)
        for item_truth, item_pred in zip(reference, prediction):
            mir_result = mir_eval.segment.detection(item_truth, item_pred, trim=True)
            result += np.array(mir_result)

        return result

    # Running model for train, test and validation.
    def run(self, dataloader, mode: str, epoch: int):
        """Run one pass of ``mode`` ('train', 'valid' or 'test') over ``dataloader``.

        'test' first loads '<epoch>.pt' from the log directory and saves every
        prediction under 'test_<epoch>/'.

        :return: (summed loss / dataset size, summed evaluation / dataset size)
        """
        self.model.train() if mode == 'train' else self.model.eval()
        if mode == 'test':
            state_dict = torch.load(Path(self.writer.logdir, f'{epoch}.pt'))
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            path_test_result = Path(self.writer.logdir, f'test_{epoch}')
            os.makedirs(path_test_result, exist_ok=True)
        else:
            path_test_result = None

        avg_loss = 0.
        avg_eval = 0.
        all_thresholds = dict()
        print()

        pbar = tqdm(dataloader, desc=f'{mode} {epoch:3d}', postfix='-', dynamic_ncols=True)

        for i_batch, (x, y, intervals, Ts, ids) in enumerate(pbar):
            # data
            # '__len__' (not 'len') detects a per-item sequence of lengths.
            n_batch = len(Ts) if hasattr(Ts, '__len__') else 1
            x = x.to(self.device)  # B, C, F, T
            x = dataloader.dataset.normalization.normalize_(x)
            y = y.to(self.out_device)  # B, T

            # Only training needs an autograd graph.
            with torch.set_grad_enabled(mode == 'train'):
                # forward
                out = self.model(x)  # B, C, 1, T
                out = out[..., 0, 0, :]  # B, T

                # loss
                if mode != 'test':
                    if mode == 'valid':
                        with torch.autograd.detect_anomaly():
                            loss = self.calc_loss(y, out, Ts)
                    else:
                        loss = self.calc_loss(y, out, Ts)
                else:
                    loss = 0

            out_np = self.sigmoid(out).detach().cpu().numpy()
            prediction, thresholds = self.predict(out_np, Ts)

            eval_result = self.evaluate(intervals, prediction)

            if mode == 'train':
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.batch_step()
                loss = loss.item()
            elif mode == 'valid':
                loss = loss.item()

                if i_batch == 0:  # save only the 0-th data
                    id_0, T_0 = ids[0], Ts[0]
                    out_np_0 = out_np[0, :T_0]
                    pred_0, truth_0 = prediction[0][1:, 0], intervals[0][1:, 0]
                    t_axis = np.arange(T_0) * self.frame2time
                    fig = draw_lineplot(t_axis, out_np_0, pred_0, truth_0, id_0)
                    self.writer.add_figure(f'{mode}/out', fig, epoch)
                    np.save(Path(self.writer.logdir, f'{id_0}_{epoch}.npy'), out_np_0)
                    np.save(Path(self.writer.logdir, f'{id_0}_{epoch}_pred.npy'), pred_0)
                    if epoch == 0:
                        np.save(Path(self.writer.logdir, f'{id_0}_truth.npy'), truth_0)
            else:
                # save all test data
                for id_, item_truth, item_pred, item_out, threshold, T \
                        in zip(ids, intervals, prediction, out_np, thresholds, Ts):
                    np.save(path_test_result / f'{id_}_truth.npy', item_truth)
                    np.save(path_test_result / f'{id_}.npy', item_out[:T])
                    np.save(path_test_result / f'{id_}_pred.npy', item_pred)
                    all_thresholds[str(id_)] = threshold

            str_eval = np.array2string(eval_result / n_batch, precision=3)
            pbar.set_postfix_str(f'{loss / n_batch:.3f}, {str_eval}')

            avg_loss += loss
            avg_eval += eval_result

        avg_loss = avg_loss / len(dataloader.dataset)
        avg_eval = avg_eval / len(dataloader.dataset)

        if mode == 'test':
            np.savez(path_test_result / 'thresholds.npz', **all_thresholds)

        return avg_loss, avg_eval

    def step(self, valid_f1: float, epoch: int):
        """Advance the scheduler one epoch and checkpoint at warm restarts.

        :param valid_f1: validation F1 of this epoch
        :param epoch: current epoch
        :return: the scheduler's ``last_restart`` as read before stepping, when
            this is a restart epoch and F1 dropped since the previous restart;
            otherwise 0
        """
        last_restart = self.scheduler.last_restart
        self.scheduler.step()  # scheduler.last_restart can be updated
        if epoch == self.scheduler.last_restart:
            if valid_f1 < self.f1_last_restart:
                return last_restart
            else:
                self.f1_last_restart = valid_f1
                # Unwrap DataParallel when present; a single-GPU or CPU model has no .module.
                model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
                torch.save(model.state_dict(), Path(self.writer.logdir, f'{epoch}.pt'))

        return 0
def train_unet(epoch=100):
    """Train UNet on the eggs/pans dataset, checkpointing the best validation IoU.

    The images in 'dataset/train/images/' are shuffled once and split 90/10 into
    training and validation sets. Whenever the running validation IoU improves,
    the weights are saved to 'checkpoints/epoch_<n>_<iou>.pth'.

    :param epoch: number of training epochs.
    """
    image_names = [
        name for name in os.listdir('dataset/train/images/')
        if name.endswith(('.jpg', '.JPG', '.png'))
    ]

    # Split into train and validation sets.
    np.random.shuffle(image_names)
    split = int(len(image_names) * 0.9)
    train_image_names = image_names[:split]
    val_image_names = image_names[split:]

    train_dataset = EggsPansDataset('dataset/train', train_image_names, mode='train')
    val_dataset = EggsPansDataset('dataset/train', val_image_names, mode='val')

    # Reshuffle training batches each epoch; the one-off split shuffle above
    # would otherwise fix the batch order for the whole run.
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)
    # mode='max': the monitored quantity (validation IoU) should increase.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', verbose=True)

    loss_obj = EggsPansLoss()
    metrics_obj = EggsPansMetricIoU()

    checkpoint_dir = 'checkpoints/'
    # torch.save raises if the destination folder is missing.
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_iou = 0.0

    for epoch_idx in range(epoch):
        print('Epoch: {:2}/{}'.format(epoch_idx + 1, epoch))

        # ---- train phase ----
        loss_obj.reset_loss()
        metrics_obj.reset_iou()
        model.train()

        pbar = tqdm(train_dataloader)
        for imgs, egg_masks, pan_masks in pbar:
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            optim.zero_grad()

            pred_egg_masks, pred_pan_masks = model(imgs)
            loss = loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
            # Called for its running-IoU update; the returned value is not needed.
            metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            # Back-propagate and update the weights.
            loss.backward()
            optim.step()

            pbar.set_description('Loss: {:5.6f}, IoU: {:5.6f}'.format(
                loss_obj.get_running_loss(), metrics_obj.get_running_iou()))

        # ---- validation phase ----
        print('Validation: ')
        loss_obj.reset_loss()
        metrics_obj.reset_iou()
        model.eval()

        pbar = tqdm(val_dataloader)
        for imgs, egg_masks, pan_masks in pbar:
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            with torch.no_grad():
                pred_egg_masks, pred_pan_masks = model(imgs)
                loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
                metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            pbar.set_description('Val Loss: {:5.6f}, IoU: {:5.6f}'.format(
                loss_obj.get_running_loss(), metrics_obj.get_running_iou()))

        # Save the model whenever the validation IoU improves.
        val_iou = metrics_obj.get_running_iou()
        if best_iou < val_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), os.path.join(
                checkpoint_dir, 'epoch_{}_{:.4f}.pth'.format(epoch_idx + 1, val_iou)))

        # Reduce learning rate on plateau.
        lr_scheduler.step(val_iou)

        print('\n')
        print('-' * 100)
class Tester:
    """Evaluate a trained UNet on the images matched by a glob pattern.

    Per image it prints precision (reported as "Accuracy"), recall, IoU and Dice,
    computed from the true-positive / false-positive / false-negative mass of the
    binarised output against the target. Summary statistics are printed at the end.
    """

    @classmethod
    def partition_masks(cls, output, target):
        """Partition the union of output and target into TP, FP and FN masks.

        Element-wise: TP = min(output, target), FP = output - TP, FN = target - TP.
        """
        true_positive_mask = torch.min(output, target)
        false_positive_mask = output - true_positive_mask
        false_negative_mask = target - true_positive_mask
        return true_positive_mask, false_positive_mask, false_negative_mask

    @classmethod
    def get_partition_measures(cls, output, target):
        """Return the ratios TP/(TP+FP), FP/(TP+FP) and FN/(TP+FN).

        These are normalised ratios, not counts, so they must not be plugged into
        count-based formulas such as Dice or IoU. The metric methods of this
        class use ``_partition_counts`` instead.
        """
        true_positive_mask, false_positive_mask, false_negative_mask = cls.partition_masks(output, target)
        tp = torch.sum(true_positive_mask) / (torch.sum(true_positive_mask) + torch.sum(false_positive_mask))
        fp = torch.sum(false_positive_mask) / (torch.sum(true_positive_mask) + torch.sum(false_positive_mask))
        fn = torch.sum(false_negative_mask) / (torch.sum(true_positive_mask) + torch.sum(false_negative_mask))
        return tp, fp, fn

    @classmethod
    def _partition_counts(cls, output, target):
        """Return the summed TP, FP and FN mass of ``partition_masks`` as 0-dim tensors."""
        tp_mask, fp_mask, fn_mask = cls.partition_masks(output, target)
        return torch.sum(tp_mask), torch.sum(fp_mask), torch.sum(fn_mask)

    @classmethod
    def _ratio(cls, numerator, denominator):
        """Return numerator / denominator as a Python float; 0 when it is undefined.

        A zero denominator (e.g. both masks empty) or a NaN result yields 0.
        """
        if denominator == 0:
            return 0
        ratio = numerator / denominator
        if math.isnan(ratio):
            return 0
        return ratio.item()

    @classmethod
    def get_dice(cls, output, target):
        """Dice coefficient 2TP / (2TP + FP + FN)."""
        tp, fp, fn = cls._partition_counts(output, target)
        return cls._ratio(2 * tp, 2 * tp + fp + fn)

    @classmethod
    def get_intersection_over_union(cls, output, target):
        """Intersection over union TP / (TP + FP + FN)."""
        tp, fp, fn = cls._partition_counts(output, target)
        return cls._ratio(tp, tp + fp + fn)

    @classmethod
    def get_accuracy(cls, output, target):
        """Precision TP / (TP + FP), reported as "Accuracy" by test_model."""
        tp, fp, _ = cls._partition_counts(output, target)
        return cls._ratio(tp, tp + fp)

    @classmethod
    def get_recall(cls, output, target):
        """Recall TP / (TP + FN)."""
        tp, _, fn = cls._partition_counts(output, target)
        return cls._ratio(tp, tp + fn)

    @classmethod
    def get_number_of_batches(cls, image_paths, batch_size):
        """Number of batches of ``batch_size`` needed to cover ``image_paths`` (last may be partial)."""
        return int(math.ceil(len(image_paths) / batch_size))

    @classmethod
    def evaluate_loss(cls, criterion, output, target):
        """criterion(output, target) plus 0.1 * (1 - IoU)."""
        loss_1 = criterion(output, target)
        loss_2 = 1 - Tester.get_intersection_over_union(output, target)
        loss = loss_1 + 0.1 * loss_2
        return loss

    def __init__(self, side_length, batch_size, seed, image_paths, state_dict):
        """
        :param side_length: passed to Loader.
        :param batch_size: images per batch.
        :param seed: numpy seed used for batch shuffling, or None.
        :param image_paths: glob pattern for the test images.
        :param state_dict: weight file name inside the 'weights/' folder.
        """
        self.side_length = side_length
        self.batch_size = batch_size
        self.seed = seed
        self.image_paths = glob.glob(image_paths)
        self.batches = Tester.get_number_of_batches(self.image_paths, self.batch_size)
        self.model = UNet()
        self.loader = Loader(self.side_length)
        self.state_dict = state_dict

    def set_cuda(self):
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def set_seed(self):
        if self.seed is not None:
            np.random.seed(self.seed)

    def process_batch(self, batch):
        """Run batch number ``batch`` through the model; return (output, clamped target)."""
        # Grab a batch, shuffled according to the provided seed. Note that
        # i-th image: samples[i][0], i-th mask: samples[i][1]
        samples = Loader.get_batch(self.image_paths, self.batch_size, batch, self.seed)

        # Cast samples into torch.FloatTensor for interaction with U-Net
        samples = torch.from_numpy(samples)
        samples = samples.float()

        # Cast into a CUDA tensor, if GPUs are available
        if torch.cuda.is_available():
            samples = samples.cuda()

        # Isolate images and their masks
        samples_images = samples[:, 0]
        samples_masks = samples[:, 1]

        # Reshape for interaction with U-Net (add a channel axis)
        samples_images = samples_images.unsqueeze(1)
        samples_masks = samples_masks.unsqueeze(1)

        # Run inputs through the model
        output = self.model(samples_images)

        # Clamp the target for proper interaction with BCELoss
        target = torch.clamp(samples_masks, min=0, max=1)

        del samples

        return output, target

    def test_model(self):
        """Load 'weights/<state_dict>' and print per-image, per-batch and overall metrics."""
        # Load onto the CPU when no GPU is available.
        map_location = None if torch.cuda.is_available() else 'cpu'
        buffered_state_dict = torch.load("weights/" + self.state_dict, map_location=map_location)

        self.model.load_state_dict(buffered_state_dict)
        self.model.eval()

        criterion = nn.BCELoss()

        perfect_accuracy_count = 0
        zero_accuracy_count = 0
        image_count = 0

        accuracy_list = []
        recall_list = []
        iou_list = []
        dice_list = []
        losses_list = []

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for batch in range(self.batches):
                output, target = self.process_batch(batch)
                loss = Tester.evaluate_loss(criterion, output, target)

                print("Batch:", batch)
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~")
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

                for i in range(output.shape[0]):
                    image_count += 1
                    binary_mask = Editor.make_binary_mask_from_torch(output[i, :, :, :], 1.0)
                    target_i = target[i, :, :, :].cpu()

                    # Metrics
                    accuracy = Tester.get_accuracy(binary_mask, target_i)
                    recall = Tester.get_recall(binary_mask, target_i)
                    iou = Tester.get_intersection_over_union(binary_mask, target_i)
                    dice = Tester.get_dice(binary_mask, target_i)

                    if accuracy == 1:
                        perfect_accuracy_count += 1
                    if accuracy == 0:
                        zero_accuracy_count += 1

                    accuracy_list.append(accuracy)
                    recall_list.append(recall)
                    iou_list.append(iou)
                    dice_list.append(dice)

                    print("Accuracy:", accuracy)
                    print("Recall:", recall)
                    print("IoU:", iou)
                    print("Dice:", dice, "\n")

                print("Mean Accuracy:", mean(accuracy_list))
                print("Mean Recall:", mean(recall_list))
                print("Mean IoU:", mean(iou_list))
                print("Mean Dice:", mean(dice_list))
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

                loss_value = loss.item()
                losses_list.append(loss_value)
                print("Test loss:", loss_value)
                print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

                del output
                del target

        if image_count == 0:
            # Nothing matched the glob pattern; avoid dividing by zero below.
            print("No images to evaluate.")
            return

        mean_iou = mean(iou_list)
        mean_accuracy = mean(accuracy_list)
        mean_recall = mean(recall_list)
        mean_dice = mean(dice_list)
        print("Perfect Accuracy Percentage:", perfect_accuracy_count / image_count)
        print("Zero Accuracy Percentage:", zero_accuracy_count / image_count)
        print("Mean Accuracy:", mean_accuracy)
        print("Mean Recall:", mean_recall)
        print("Mean IoU:", mean_iou)
        print("Mean Dice:", mean_dice)
def video_infer(args):
    """Segment a video with optical-flow smoothing and write a white-bg matte.

    Every frame is run through the selected segmentation model. The
    foreground probability (softmax channel 1) is tracked across frames with
    DIS optical flow via ``postprocess``, blurred, passed through
    ``threshold_mask`` and used to composite the frame over a white
    background. Results are written to ``args.output`` as MJPG at the source
    frame rate and optionally shown when ``args.watch`` is set (``q`` stops).

    Args:
        args: namespace with ``video``, ``output``, ``model``
            (``'unet'`` or ``'deeplabv3_plus'``), ``net``, ``checkpoint``,
            ``use_cuda``, ``input_sz`` and ``watch``. ``args.bg`` is not used
            by this pipeline.

    Raises:
        IOError: if no frame can be read from ``args.video``.
        ValueError: if ``args.model`` is not a supported model name.
    """
    cap = cv2.VideoCapture(args.video)
    # The first frame is read up-front to learn the frame size; it is kept
    # and fed to the loop below so that it also appears in the output.
    ret, first_frame = cap.read()
    if not ret or first_frame is None:
        cap.release()
        raise IOError("Could not read a frame from %s" % args.video)
    H, W = first_frame.shape[:2]
    fps = cap.get(cv2.CAP_PROP_FPS)
    out = cv2.VideoWriter(args.output,
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                          (W, H))
    try:
        if args.model == 'unet':
            model = UNet(backbone=args.net, num_classes=2,
                         pretrained_backbone=None)
        elif args.model == 'deeplabv3_plus':
            model = DeepLabV3Plus(backbone=args.net, num_classes=2,
                                  pretrained_backbone=None)
        else:
            raise ValueError("Unsupported model: %r" % (args.model,))
        if args.use_cuda:
            model = model.cuda()
        trained_dict = torch.load(args.checkpoint,
                                  map_location="cpu")['state_dict']
        model.load_state_dict(trained_dict, strict=False)
        model.eval()

        disflow = cv2.DISOpticalFlow_create(
            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
        # Tracking state. It is sized lazily from the first resized gray
        # frame so it always matches the (h_new, w_new) that
        # utils.preprocessing actually returns.
        prev_gray = None
        prev_cfd = None
        is_init = True
        pending = first_frame

        while cap.isOpened():
            start_time = time()
            if pending is not None:
                ret, frame, pending = True, pending, None
            else:
                ret, frame = cap.read()
            if not ret:
                break
            image = frame[..., ::-1]  # BGR -> RGB
            read_cam_time = time()

            # Predict mask
            X, pad_up, pad_left, h_new, w_new = utils.preprocessing(
                image, expected_size=args.input_sz, pad_value=0)
            preproc_time = time()
            with torch.no_grad():
                logits = model(X.cuda() if args.use_cuda else X)
                # Strip the padding added by preprocessing.
                logits = logits[..., pad_up:pad_up + h_new,
                                pad_left:pad_left + w_new]
                # Foreground probability of the first sample in the batch.
                mask = F.softmax(logits, dim=1)[0, 1, ...].cpu().numpy()
            predict_time = time()

            # Optical-flow tracking on the resized gray frame.
            cur_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            cur_gray = cv2.resize(cur_gray, (w_new, h_new))
            if prev_gray is None:
                prev_gray = np.zeros_like(cur_gray)
                prev_cfd = np.zeros(cur_gray.shape, np.float32)
            scoremap = 255 * mask
            optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd,
                                      disflow, is_init)
            optical_flow_track_time = time()
            prev_gray = cur_gray.copy()
            prev_cfd = optflow_map.copy()
            is_init = False

            # Soften and threshold the tracked map, then alpha-blend the
            # frame over a white background.
            optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
            optflow_map = threshold_mask(optflow_map, thresh_bg=0.2,
                                         thresh_fg=0.8)
            img_matting = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
            bg_im = np.ones_like(img_matting) * 255
            re_image = cv2.resize(image, (w_new, h_new))
            comb = (img_matting * re_image +
                    (1 - img_matting) * bg_im).astype(np.uint8)
            comb = cv2.resize(comb, (W, H))
            comb = comb[..., ::-1]  # RGB -> BGR for OpenCV output

            # Print runtime (the reported fps covers model inference only).
            read = read_cam_time - start_time
            preproc = preproc_time - read_cam_time
            pred = predict_time - preproc_time
            optical = optical_flow_track_time - predict_time
            total = read + preproc + pred + optical
            print(
                "read: %.3f [s]; preproc: %.3f [s]; pred: %.3f [s]; optical: %.3f [s]; total: %.3f [s]; fps: %.2f [Hz]"
                % (read, preproc, pred, optical, total, 1 / pred))

            out.write(comb)
            if args.watch:
                cv2.imshow('webcam', comb[..., ::-1])
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always finalise the writer and free the capture, even on errors.
        cap.release()
        out.release()
        if args.watch:
            cv2.destroyAllWindows()
def video_infer(args):
    """Segment every frame of a video and write the composited result.

    Each frame is segmented by the selected model. The foreground
    probability (softmax channel 1) is resized to the frame size and used
    either to draw a matting visualisation (``utils.draw_matting``) or, when
    ``args.bg`` is given, to paste the foreground onto that background image
    (``utils.draw_fore_to_back``). The inference-only frame rate is stamped
    on each frame. Output goes to ``args.output`` (DIVX, source frame rate)
    and is optionally displayed when ``args.watch`` is set (``q`` stops).

    Args:
        args: namespace with ``video``, ``output``, ``bg``, ``model``
            (``'unet'``, ``'deeplabv3_plus'`` or ``'hrnet'``), ``net``,
            ``checkpoint``, ``use_cuda``, ``input_sz`` and ``watch``.

    Raises:
        IOError: if no frame can be read from ``args.video``.
        ValueError: if ``args.model`` is not a supported model name.
    """
    cap = cv2.VideoCapture(args.video)
    # The first frame is read up-front to learn the frame size; it is kept
    # and fed to the loop below so that it also appears in the output.
    ret, first_frame = cap.read()
    if not ret or first_frame is None:
        cap.release()
        raise IOError("Could not read a frame from %s" % args.video)
    H, W = first_frame.shape[:2]
    # Write at the source frame rate; fall back to 30 if none is reported.
    out_fps = cap.get(cv2.CAP_PROP_FPS) or 30
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    out = cv2.VideoWriter(args.output, fourcc, out_fps, (W, H))
    font = cv2.FONT_HERSHEY_SIMPLEX
    try:
        # Background replacement settings (only used when args.bg is set).
        if args.bg is not None:
            BACKGROUND = cv2.imread(args.bg)[..., ::-1]
            BACKGROUND = cv2.resize(BACKGROUND, (W, H),
                                    interpolation=cv2.INTER_LINEAR)
            KERNEL_SZ = 25
            SIGMA = 0

        if args.model == 'unet':
            model = UNet(backbone=args.net, num_classes=2,
                         pretrained_backbone=None)
        elif args.model == 'deeplabv3_plus':
            model = DeepLabV3Plus(backbone=args.net, num_classes=2,
                                  pretrained_backbone=None)
        elif args.model == 'hrnet':
            model = HighResolutionNet(num_classes=2, pretrained_backbone=None)
        else:
            raise ValueError("Unsupported model: %r" % (args.model,))
        if args.use_cuda:
            model = model.cuda()
        trained_dict = torch.load(args.checkpoint,
                                  map_location="cpu")['state_dict']
        model.load_state_dict(trained_dict, strict=False)
        model.eval()

        padded_size = (args.input_sz, args.input_sz)
        pending = first_frame
        while cap.isOpened():
            start_time = time()
            if pending is not None:
                ret, frame, pending = True, pending, None
            else:
                ret, frame = cap.read()
            if not ret:
                break
            image = frame[..., ::-1]  # BGR -> RGB
            h, w = image.shape[:2]
            read_cam_time = time()

            # Predict mask
            X, pad_up, pad_left, h_new, w_new = utils.preprocessing(
                image, expected_size=args.input_sz, pad_value=0)
            preproc_time = time()
            with torch.no_grad():
                logits = model(X.cuda() if args.use_cuda else X)
                # Some models (e.g. HRNet) predict at reduced resolution:
                # bring the logits back to the padded input size before the
                # padding is cropped away. Done on CPU and GPU alike.
                if tuple(logits.shape[-2:]) != padded_size:
                    logits = F.interpolate(logits, size=padded_size,
                                           mode='bilinear',
                                           align_corners=True)
                logits = logits[..., pad_up:pad_up + h_new,
                                pad_left:pad_left + w_new]
                logits = F.interpolate(logits, size=(h, w), mode='bilinear',
                                       align_corners=True)
                mask = F.softmax(logits, dim=1)[0, 1, ...].cpu().numpy()
            predict_time = time()

            # Draw result
            if args.bg is None:
                image_alpha = utils.draw_matting(image, mask)
            else:
                image_alpha = utils.draw_fore_to_back(image, mask, BACKGROUND,
                                                      kernel_sz=KERNEL_SZ,
                                                      sigma=SIGMA)
            draw_time = time()

            # Print runtime (the reported fps covers model inference only).
            read = read_cam_time - start_time
            preproc = preproc_time - read_cam_time
            pred = predict_time - preproc_time
            draw = draw_time - predict_time
            total = read + preproc + pred + draw
            fps = 1 / pred
            print("read: %.3f [s]; preproc: %.3f [s]; pred: %.3f [s]; draw: %.3f [s]; total: %.3f [s]; fps: %.2f [Hz]" %
                  (read, preproc, pred, draw, total, fps))

            cv2.putText(image_alpha, "%.2f [fps]" % (fps), (10, 50), font,
                        1.5, (0, 255, 0), 2, cv2.LINE_AA)
            out.write(image_alpha[..., ::-1])
            if args.watch:
                cv2.imshow('webcam', image_alpha[..., ::-1])
                # Wait for interrupt
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always finalise the writer and free the capture.
        cap.release()
        out.release()
        if args.watch:
            cv2.destroyAllWindows()
def main(max_images=4):
    """Evaluate the trained UNet on the spectral test split.

    Loads ``./checkpoints_1_19/checkpoint_30.pth`` and runs the model on
    256x256 centre crops of the test images. For every evaluated image it
    saves the prediction, label and input (multiplied by 1023 and cast to
    int16) as PNGs under ``./test_results/{pred,label,original}/``. It then
    prints the mean PSNR/SSIM of the prediction against the label, and the
    same metrics for the unprocessed input against the label as a baseline.

    Args:
        max_images: stop after this many images; ``None`` evaluates the
            whole split. Defaults to 4, the limit the previous
            ``count == 4`` check was meant to enforce. That check only left
            the inner loop, so later batches were still processed.

    Note:
        ``input_dim``, ``label_dim`` and ``batch_size`` are read from module
        scope; they are not defined in this function.
    """
    import os

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    unet = UNet(input_dim, label_dim)
    unet.load_state_dict(
        torch.load('./checkpoints_1_19/checkpoint_30.pth', map_location='cpu'))
    unet.eval()
    if use_cuda:
        unet.cuda()

    transform = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    dataset = SpectralDataSet(
        root_dir=
        '/mnt/liguanlin/DataSets/lowlight_hyperspectral_datasets/band_splited_dataset',
        type_name='test',
        transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    pred_dir = "./test_results/pred/"
    label_dir = "./test_results/label/"
    origin_dir = "./test_results/original/"
    for path in (pred_dir, label_dir, origin_dir):
        os.makedirs(path, exist_ok=True)

    def to_int16_image(batch, index):
        # Reshape one (1, H, W) sample to (H, W) -- this requires a single
        # channel -- and scale by 1023 (assumes values in [0, 1]; confirm).
        img = batch[index].reshape((batch.shape[2], batch.shape[3]))
        return (img * 1023).astype(np.int16)

    count = 0
    total_psnr = 0
    total_ssim = 0
    goundtruth_total_pnsr = 0
    goundtruth_total_ssim = 0
    done = False
    for real, labels in tqdm(dataloader):
        real = real.to(device)
        labels = labels.to(device)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            pred = unet(real)
        print(pred.shape)
        print(real.shape)
        save_image(pred, 'pred.png', nrow=2)
        save_image(labels, 'labels.png', nrow=2)
        pred_numpy = pred.cpu().numpy()
        label_numpy = labels.cpu().numpy()
        origin_numpy = real.cpu().numpy()
        print(pred_numpy.shape)
        for i in range(pred_numpy.shape[0]):
            pred_numpy_img = to_int16_image(pred_numpy, i)
            label_numpy_img = to_int16_image(label_numpy, i)
            origin_numpy_img = to_int16_image(origin_numpy, i)

            # Name files by running image index so later batches do not
            # overwrite earlier ones (identical to the old names in batch 1).
            name = str(count) + ".png"
            Image.fromarray(pred_numpy_img).save(pred_dir + name)
            Image.fromarray(label_numpy_img).save(label_dir + name)
            Image.fromarray(origin_numpy_img).save(origin_dir + name)

            count += 1
            # Prediction quality: prediction vs. ground-truth label.
            total_psnr += caculate_psnr_16bit(pred_numpy_img, label_numpy_img)
            total_ssim += caculate_ssim_16bit(pred_numpy_img, label_numpy_img)
            # Baseline: unprocessed input vs. ground-truth label.
            goundtruth_total_pnsr += caculate_psnr_16bit(
                label_numpy_img, origin_numpy_img)
            goundtruth_total_ssim += caculate_ssim_16bit(
                label_numpy_img, origin_numpy_img)
            if max_images is not None and count >= max_images:
                done = True
                break
        if done:
            break

    if count == 0:
        print("No test images were evaluated.")
        return

    mean_psnr = total_psnr / count
    mean_ssim = total_ssim / count
    print("count = ", count)
    print("mean_psnr = ", mean_psnr)
    print("mean_ssim = ", mean_ssim)
    gound_truth_mean_psnr = goundtruth_total_pnsr / count
    gound_truth_mean_ssim = goundtruth_total_ssim / count
    print("gound_truth_mean_psnr = ", gound_truth_mean_psnr)
    print("gound_truth_mean_ssim = ", gound_truth_mean_ssim)
def train_val(config):
    """Train a 15-class segmentation model and validate it every epoch.

    Builds the model named by ``config.model_type``. When
    ``config.iscontinue`` is set, the model is replaced by the hard-coded
    checkpoint below. Training sums the criterion over the model's two
    outputs (``aux_out, out``). After every epoch, validation builds a
    15x15 confusion matrix from the second output and derives per-class IoU
    (``get_miou``) and a frequency-weighted mIoU. The whole model is saved to
    ``config.result_path`` every epoch. Losses, IoUs and a few sample masks
    are logged to TensorBoard.

    Args:
        config: namespace providing the train/val image and mask dirs,
            ``batch_size``, ``num_workers``, ``smooth``, ``lr``,
            ``model_type``, ``data_type``, ``output_ch``, ``iscontinue``,
            ``optimizer``, ``num_epochs``, ``num_train``, ``num_val`` and
            ``result_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_loader = get_dataloader(img_dir=config.train_img_dir, mask_dir=config.train_mask_dir, mode="train", batch_size=config.batch_size, num_workers=config.num_workers, smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir, mask_dir=config.val_mask_dir, mode="val", batch_size=config.batch_size, num_workers=config.num_workers)
    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" % (config.lr, config.batch_size, config.model_type, config.data_type))
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        # src = "./pretrained/60_DANet_0.8086.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # pretrained_dict.pop('seg1.1.weight')
        # pretrained_dict.pop('seg1.1.bias')
        model = DANet(backbone='resnext101', nclass=config.output_ch, pretrained=True, norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "Deeplabv3+":
        # src = "./pretrained/Deeplabv3+.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # # print(pretrained_dict.keys())
        # for key in list(pretrained_dict.keys()):
        #     if key.split('.')[0] == "cbr_last":
        #         pretrained_dict.pop(key)
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3, num_classes=config.output_ch, backend='resnet101', os=16, pretrained=True, norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        # Resume from a fixed checkpoint that stores a pickled DataParallel
        # wrapper; `.module` unwraps it. This overrides config.model_type.
        model = torch.load("./exp/13_Deeplabv3+_0.7619.pth", map_location='cpu').module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # PyTorch 1.6.0 compatibility for models pickled by older versions

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    # Pixel values of each class in the masks, and the matching class names
    # used as TensorBoard tags: water body, road, building, airport, parking
    # lot, sports field, ordinary farmland, agricultural greenhouse, natural
    # grassland, urban green space, natural forest, planted forest, natural
    # bare soil, man-made bare soil, other.
    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    objects = [
        '水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化', '自然林', '人工林', '自然裸土', '人为裸土', '其它'
    ]
    # Per-class weights for the frequency-weighted mIoU (presumably class
    # pixel frequencies of the dataset -- confirm).
    frequency = np.array([
        0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107, 0.1207, 0.0249, 0.1470, 0.0777, 0.0617, 0.0118, 0.0187
    ])

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(), lr=config.lr, weight_decay=1e-4, momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)
    if config.smooth == "all":
        criterion = LabelSmoothSoftmaxCE()
    elif config.smooth == "edge":
        criterion = LabelSmoothCE()
    else:
        criterion = nn.CrossEntropyLoss()
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    # Stepped once per epoch below, so the cosine cycle restarts every 15 epochs.
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, eta_min=1e-4)
    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        # NOTE(review): the random draw is immediately overwritten with 0, so
        # the seed == 1 (half batch, 384x384 resize) path below is dead. The
        # draw still advances numpy's global RNG, so removing it would change
        # any later np.random-based behaviour.
        seed = np.random.randint(0, 2, 1)
        seed = 0
        print("seed is ", seed)
        # NOTE(review): with seed == 0 this rebuilds loaders identical to the
        # ones created at the top of the function, once per epoch.
        if seed == 1:
            train_loader = get_dataloader(img_dir=config.train_img_dir, mask_dir=config.train_mask_dir, mode="train", batch_size=config.batch_size // 2, num_workers=config.num_workers, smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir, mask_dir=config.val_mask_dir, mode="val", batch_size=config.batch_size // 2, num_workers=config.num_workers)
        else:
            train_loader = get_dataloader(img_dir=config.train_img_dir, mask_dir=config.train_mask_dir, mode="train", batch_size=config.batch_size, num_workers=config.num_workers, smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir, mask_dir=config.val_mask_dir, mode="val", batch_size=config.batch_size, num_workers=config.num_workers)
        cm = np.zeros([15, 15])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train, desc="Epoch %d / %d" % (epoch + 1, config.num_epochs), unit='img', ncols=100) as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                if seed == 0:
                    pass
                elif seed == 1:
                    image = F.interpolate(image, size=(384, 384), mode='bilinear', align_corners=True)
                    mask = F.interpolate(mask.float(), size=(384, 384), mode='nearest')
                if config.smooth == "edge":
                    # Edge smoothing keeps the mask as float targets.
                    mask = mask.to(device, dtype=torch.float32)
                else:
                    # argmax over dim 1 turns per-class channel maps into class indices.
                    mask = mask.to(device, dtype=torch.long).argmax(dim=1)
                # The model returns (auxiliary output, main output); both are supervised.
                aux_out, out = model(image)
                aux_loss = criterion(aux_out, mask)
                seg_loss = criterion(out, mask)
                loss = aux_loss + seg_loss
                # pred = model(image)
                # loss = criterion(pred, mask)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break
        # scheduler.step()
        # Mean per-batch loss, approximating the batch count as num_train / batch_size.
        print("\ntraining epoch loss: " + str(epoch_loss / (float(config.num_train) / (float(config.batch_size)))))
        torch.cuda.empty_cache()
        # Sum (not mean) of per-batch cross-entropy over the validation set.
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val, desc="Epoch %d / %d validation round" % (epoch + 1, config.num_epochs), unit='img', ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # Only the main output is evaluated.
                    _, pred = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    # Map both maps to label values before accumulating the confusion matrix.
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log samples 2 and 3 of the sixth batch (needs >= 4 samples per batch).
                    if locker == 5:
                        writer.add_images('mask_a/true', mask[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_a/pred', pred[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/true', mask[3, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/pred', pred[3, :, :], epoch + 1, dataformats='HW')
                    locker += 1
                    # break
        miou = get_miou(cm)
        fw_miou = (miou * frequency).sum()
        scheduler.step()
        # Saves a checkpoint every epoch (the `if True` bypasses any best-score
        # gating); max_fwiou is tracked but not used for gating here.
        if True:
            # NOTE(review): exact string match misses builds such as
            # "1.6.0+cu101"; legacy serialization keeps checkpoints loadable
            # by older PyTorch versions.
            if torch.__version__ == "1.6.0":
                torch.save(model, config.result_path + "/%d_%s_%.4f.pth" % (epoch + 1, config.model_type, fw_miou), _use_new_zipfile_serialization=False)
            else:
                torch.save(
                    model, config.result_path + "/%d_%s_%.4f.pth" % (epoch + 1, config.model_type, fw_miou))
            max_fwiou = fw_miou
        print("\n")
        print(miou)
        print("testing epoch loss: " + str(val_loss), "FWmIoU = %.4f" % fw_miou)
        writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
        writer.add_scalar('loss/val', val_loss, epoch + 1)
        for idx, name in enumerate(objects):
            writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
        torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
def train_val(config):
    """Train a binary segmentation model and validate it every epoch.

    Builds the model named by ``config.model_type`` and trains it with the
    loss selected by ``config.loss``. A StepLR schedule, stepped once per
    epoch, multiplies the LR by 0.1 every 40 epochs. After each epoch,
    dataset-averaged Dice, accuracy, sensitivity, specificity, precision and
    F1 are computed on the validation set and logged to TensorBoard. The
    whole model is saved to ``config.result_path`` whenever the validation
    Dice improves.

    The training and validation loops unpack eight outputs ``d0..d7`` from
    the model. Training passes all of them plus the mask to the criterion,
    so the chosen model/loss pair must follow that contract (e.g. BASNet
    with BasLoss).

    Args:
        config: namespace with the train/val image and mask dirs,
            ``batch_size``, ``num_workers``, ``lr``, ``model_type``,
            ``data_type``, ``optimizer``, ``loss``, ``num_epochs``,
            ``num_train``, ``num_val`` and ``result_path``.

    Returns:
        None. Returns early, after printing an error, if
        ``config.model_type`` is not supported.
    """
    # Validate the model name before creating loaders or a SummaryWriter,
    # so an invalid config does not leave an unclosed writer behind.
    if config.model_type not in [
            'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
            'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "BASNet"
    ]:
        print('ERROR!! model_type should be selected in supported models')
        print('Choose model %s' % config.model_type)
        return

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)
    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=1)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=1)
    elif config.model_type == "BASNet":
        model = BASNet(n_channels=3, n_classes=1)
    else:
        # Supported names without a dedicated branch fall back to UNet.
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device, dtype=torch.float)

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "bas":
        criterion = BasLoss()
    else:
        criterion = MixLoss()
    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                loss = criterion(d0, d1, d2, d3, d4, d5, d6, d7, mask)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
        scheduler.step()

        # Validation: sample-weighted sums, divided by the sample count below.
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        # No autograd graph is needed for evaluation; this matches the other
        # training scripts and avoids holding activations in memory.
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img') as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    current_num += image.shape[0]
                    image = image.to(device, dtype=torch.float)
                    mask = mask.to(device, dtype=torch.float)
                    # Only the first (main) output d0 is evaluated.
                    d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                    batch_dice = dice_coeff(mask, d0).item()
                    epoch_dice += batch_dice * image.shape[0]
                    epoch_acc += get_accuracy(pred=d0, true=mask) * image.shape[0]
                    epoch_sen += get_sensitivity(pred=d0, true=mask) * image.shape[0]
                    epoch_spe += get_specificity(pred=d0, true=mask) * image.shape[0]
                    epoch_pre += get_precision(pred=d0, true=mask) * image.shape[0]
                    if locker == 200:
                        writer.add_images('masks/true', mask, epoch + 1)
                        writer.add_images('masks/pred', d0 > 0.5, epoch + 1)
                    val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                    val_pbar.update(image.shape[0])
                    locker += 1

        epoch_dice /= float(current_num)
        epoch_acc /= float(current_num)
        epoch_sen /= float(current_num)
        epoch_spe /= float(current_num)
        epoch_pre /= float(current_num)
        epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
        if epoch_dice > best_dice:
            best_dice = epoch_dice
            writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
            torch.save(
                model, config.result_path + "/%s_%s_%d.pth" %
                (config.model_type, str(epoch_dice), epoch + 1))
        logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
        print("epoch dice: " + str(epoch_dice))
        writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
        writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
        writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
        writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
        writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
        writer.add_scalar('F1/test', epoch_f1, epoch + 1)
    writer.close()
    print("Training finished")
def train_val(config):
    """Train an 8-class segmentation model with BasLoss and validate per epoch.

    Builds the model named by ``config.model_type``. When
    ``config.iscontinue`` is set, the model is replaced by the hard-coded
    checkpoint below. Training uses ``BasLoss`` on the full model output.
    Validation uses the first of the model's eight outputs to build an 8x8
    confusion matrix, per-class IoU (``get_miou``) and a frequency-weighted
    mIoU. The whole model is saved to ``config.result_path`` whenever the
    FW-mIoU improves.

    Args:
        config: namespace with the train/val image and mask dirs,
            ``batch_size``, ``num_workers``, ``smooth``, ``lr``,
            ``model_type``, ``data_type``, ``output_ch``, ``iscontinue``,
            ``optimizer``, ``num_epochs``, ``num_train``, ``num_val`` and
            ``result_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    # Validation batch size is fixed at 4; the sample logging below reads
    # items 2 and 3 of a batch.
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=4,
                                num_workers=config.num_workers)
    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "BASNet":
        model = BASNet(n_classes=8)
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=8,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        # Resume from a fixed checkpoint that stores a pickled DataParallel
        # wrapper; `.module` unwraps it. This overrides config.model_type.
        model = torch.load(
            "./exp/24_Deeplabv3+_0.7825757691389714.pth").module

    for k, m in model.named_modules():
        # PyTorch 1.6.0 compatibility for models pickled by older versions.
        m._non_persistent_buffers_set = set()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    # Mask pixel values of each class and the matching class names used as
    # TensorBoard tags: water, transport infrastructure, building, cropland,
    # grassland, woodland, bare soil, other.
    labels = [100, 200, 300, 400, 500, 600, 700, 800]
    objects = ['水体', '交通建筑', '建筑', '耕地', '草地', '林地', '裸土', '其他']

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    criterion = BasLoss()
    # Stepped once per epoch below, so the cosine cycle restarts every 15 epochs.
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)
    global_step = 0
    max_fwiou = 0
    # Per-class weights for the frequency-weighted mIoU.
    frequency = np.array(
        [0.1051, 0.0607, 0.1842, 0.1715, 0.0869, 0.1572, 0.0512, 0.1832])
    # Legacy (non-zip) serialization on 1.6.0, ignoring local build suffixes
    # such as "+cu101" that an exact string comparison would miss.
    use_legacy_save = torch.__version__.split("+")[0] == "1.6.0"

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        cm = np.zeros([8, 8])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                # Keep the target in the same dtype as the float32
                # predictions (float16 targets mismatch float32 losses).
                mask = mask.to(device, dtype=torch.float32)
                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
        # Mean per-batch loss, approximating the batch count as num_train / batch_size.
        print("\ntraining epoch loss: " + str(epoch_loss / (float(config.num_train) / (float(config.batch_size)))))
        torch.cuda.empty_cache()

        # Sum (not mean) of per-batch cross-entropy over the validation set.
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    # argmax over dim 1 turns per-class channel maps into class indices.
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # Only the first of the eight outputs is evaluated.
                    pred, _, _, _, _, _, _, _ = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    if locker == 25:
                        writer.add_images('mask_a/true', mask[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_a/pred', pred[2, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/true', mask[3, :, :], epoch + 1, dataformats='HW')
                        writer.add_images('mask_b/pred', pred[3, :, :], epoch + 1, dataformats='HW')
                    locker += 1

        miou = get_miou(cm)
        fw_miou = (miou * frequency).sum()
        scheduler.step()

        # Keep only checkpoints that improve the frequency-weighted mIoU.
        if fw_miou > max_fwiou:
            save_path = config.result_path + "/%d_%s_%.4f.pth" % (
                epoch + 1, config.model_type, fw_miou)
            if use_legacy_save:
                torch.save(model, save_path,
                           _use_new_zipfile_serialization=False)
            else:
                torch.save(model, save_path)
            max_fwiou = fw_miou
        print("\n")
        print(miou)
        print("testing epoch loss: " + str(val_loss),
              "FWmIoU = %.4f" % fw_miou)
        writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
        writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
        writer.add_scalar('loss/val', val_loss, epoch + 1)
        for idx, name in enumerate(objects):
            writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
        torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
def main():
    """Train a conditional GAN (UNet generator ``G``, pair discriminator ``D``).

    ``G`` maps image ``B`` to an estimate of image ``A``, and ``D`` scores
    ``(image, B)`` pairs. Each batch performs one discriminator update and
    then one generator update. The generator loss is the adversarial term
    plus ``args.lambda_recon`` times the L1 distance to ``A``. After every
    epoch, the sample grid ``[A | B | G(B)]`` from the first test batch is
    written to ``images/<epoch>.png``, and both networks are saved to
    ``args.save_path``.

    ``args``, ``device``, ``dataloader`` and ``datasets`` are read from
    module scope. The log terms assume ``D`` outputs probabilities in
    (0, 1); confirm against ``Discriminator``.

    Raises:
        FileNotFoundError: if ``args.reuse`` is set but ``args.save_path``
            does not exist.
    """
    # Networks
    G = UNet().to(device)
    D = Discriminator().to(device)

    # Network weight initialisation
    G.apply(weight_init)
    D.apply(weight_init)

    # Resume from a saved checkpoint
    if args.reuse:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(args.save_path):
            raise FileNotFoundError('[!]Pretrained model not found')
        checkpoint = torch.load(args.save_path)
        G.load_state_dict(checkpoint['G'])
        D.load_state_dict(checkpoint['D'])
        print('[*]Pretrained model loaded')

    # Optimizers
    G_optim = optim.Adam(G.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    D_optim = optim.Adam(D.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    for epoch in range(args.num_epoch):
        for i, imgs in enumerate(dataloader['train']):
            A = imgs['A'].to(device)
            B = imgs['B'].to(device)

            # ---------------- Discriminator ----------------
            G.eval()
            D.train()
            # G is not updated in this step, so build its output without an
            # autograd graph; D sees exactly the same values.
            with torch.no_grad():
                fake = G(B)
            D_fake = D(fake, B)
            D_real = D(A, B)

            # Standard GAN discriminator loss: -mean(log D(real) + log(1 - D(fake)))
            loss_D = -((D_real.log() + (1 - D_fake).log()).mean())
            # LSGAN alternative:
            # loss_D = ((D_real - 1)**2).mean() + (D_fake**2).mean()

            D_optim.zero_grad()
            loss_D.backward()
            D_optim.step()

            # ---------------- Generator ----------------
            G.train()
            D.eval()
            fake = G(B)
            D_fake = D(fake, B)

            # Non-saturating generator loss -mean(log D(fake)). The per-element
            # log is taken before the mean, as in loss_D, rather than the log
            # of the mean. Plus weighted L1 reconstruction.
            loss_G = -(D_fake.log().mean()
                       ) + args.lambda_recon * torch.abs(A - fake).mean()
            # LSGAN alternative:
            # loss_G = ((D_fake-1)**2).mean() + args.lambda_recon * torch.abs(A - fake).mean()

            G_optim.zero_grad()
            loss_G.backward()
            G_optim.step()

            # Print training progress
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, args.num_epoch, i * args.batch_size,
                   len(datasets['train']), loss_D.item(), loss_G.item()))

        # Save a sample grid once per epoch
        val = next(iter(dataloader['test']))
        real_A = val['A'].to(device)
        real_B = val['B'].to(device)
        with torch.no_grad():
            fake_A = G(real_B)
        save_image(torch.cat([real_A, real_B, fake_A], dim=3),
                   'images/{0:03d}.png'.format(epoch + 1),
                   nrow=2,
                   normalize=True)

        # Save the model
        torch.save({
            'G': G.state_dict(),
            'D': D.state_dict(),
        }, args.save_path)
class Anonymizer:
    """Segment images with a U-Net and inpaint the regions the mask selects.

    Images matched by a glob pattern are pulled in batches through
    ``Loader.get_batch`` and run through ``self.model``. For every image, the
    predicted mask (binarised and inverted with ``Editor`` helpers) selects
    pixels that are blanked and then filled with ``cv2.inpaint``. The model
    input and the anonymised result are written to ``write_path``.

    Intended call order, presumably (confirm against callers): construct,
    ``set_cuda()``, ``set_weights()``, then ``anonymize()``.
    """

    @classmethod
    def check_for_numpy(cls, tensor):
        """Return ``tensor`` as a NumPy array if it is a ``torch.Tensor``.

        Tensors are detached and copied to the CPU first; any other value is
        returned unchanged.
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.detach().cpu().numpy()
        return tensor

    @classmethod
    def get_number_of_batches(cls, image_paths, batch_size):
        """Return how many batches of ``batch_size`` cover ``image_paths``.

        A trailing partial batch counts as one batch (ceiling division).
        """
        return math.ceil(len(image_paths) / batch_size)

    @classmethod
    def apply_mask(cls, image, mask):
        """Return the element-wise product of ``image`` and ``mask``."""
        return np.multiply(image, mask)

    @classmethod
    def anonymize_image(cls, image, mask):
        """Zero the pixels where ``mask`` is 0, then inpaint them.

        Args:
            image: a single image (tensor or array); reshaped to its last two
                dimensions and cast to float32.
            mask: mask with the same spatial size; reshaped likewise and cast
                to uint8. The image is multiplied by it. The mask is then
                passed through ``Editor.invert_mask`` and used as the
                inpainting mask. This presumably selects exactly the zeroed
                pixels (TODO confirm ``invert_mask`` semantics).

        Returns:
            The inpainted image (float32, like the masked input).

        Side effects:
            Writes debugging snapshots ``sanity_mask.jpg``,
            ``sanity_image.jpg``, ``sanity_join.jpg`` and ``sanity_anon.jpg``
            to the current working directory, overwriting them on each call.
        """
        image = Anonymizer.check_for_numpy(image)
        image = image.reshape(image.shape[-2:])
        image = np.float32(image)
        mask = Anonymizer.check_for_numpy(mask)
        mask = mask.reshape(mask.shape[-2:])
        mask = np.uint8(mask)
        im = Anonymizer.apply_mask(image, mask)
        cv2.imwrite("sanity_mask.jpg", 255 * mask)
        cv2.imwrite("sanity_image.jpg", image)
        cv2.imwrite("sanity_join.jpg", im)
        mask = Editor.invert_mask(mask)
        mask = np.uint8(mask)
        # Inpainting radius 10 px, Telea's fast-marching method.
        im = cv2.inpaint(im, mask, 10, cv2.INPAINT_TELEA)
        cv2.imwrite("sanity_anon.jpg", im)
        return im

    def __init__(self, batch_size, image_paths, write_path, state_dict):
        """
        Args:
            batch_size: number of images per batch.
            image_paths: glob pattern; expanded immediately with ``glob.glob``.
            write_path: output directory for the generated images.
            state_dict: weights file name, resolved under ``weights/`` by
                ``set_weights``.
        """
        self.batch_size = batch_size
        self.image_paths = glob.glob(image_paths)
        self.batches = Anonymizer.get_number_of_batches(
            self.image_paths, self.batch_size)
        self.write_path = write_path
        self.model = UNet()
        self.state_dict = state_dict

    def process_batch(self, batch):
        """Load batch number ``batch`` and run its images through the model.

        Returns:
            ``(source, output, target)``: the input images with a channel
            axis added, the raw model output, and the masks clamped to
            ``[0, 1]``.
        """
        # Grab a batch; the last argument is passed as None (per the original
        # comment it is the shuffle seed). Note that
        # i-th image: samples[i][0], i-th mask: samples[i][1]
        samples = Loader.get_batch(self.image_paths, self.batch_size, batch,
                                   None)
        # ndarray.astype returns a new array, so the converted result has to
        # be kept. The U-Net expects a torch.FloatTensor (float32).
        samples = torch.from_numpy(np.asarray(samples, dtype=np.float32))

        # Cast into a CUDA tensor, if GPUs are available
        if torch.cuda.is_available():
            samples = samples.cuda()

        # Isolate images and their masks
        samples_images = samples[:, 0]
        samples_masks = samples[:, 1]

        # Reshape for interaction with U-Net
        samples_images = samples_images.unsqueeze(1)
        source = samples_images
        samples_masks = samples_masks.unsqueeze(1)

        # Run inputs through the model
        output = self.model(samples_images)

        # Clamp the target for proper interaction with BCELoss
        target = torch.clamp(samples_masks, min=0, max=1)
        del samples

        return source, output, target

    def anonymize(self):
        """Anonymise every matched image and write the results.

        For each image, ``orig_<k>.jpg`` (the model input) and
        ``anon_<k>.jpg`` (the inpainted result) are written into
        ``write_path``. Here ``k`` counts images across all batches.
        """
        if not os.path.isdir(self.write_path):
            print("Making output directory")
            # makedirs also creates missing parent directories.
            os.makedirs(self.write_path, exist_ok=True)

        count = 0
        for batch in range(self.batches):
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                source, output, target = self.process_batch(batch)
            sources = Anonymizer.check_for_numpy(source)

            # anonymize_image works on a single 2-D image, so handle each
            # image of the batch separately.
            for k in range(sources.shape[0]):
                image = sources[k].reshape(sources.shape[-2:])
                image = np.float32(image)

                binary_mask = Editor.make_binary_mask_from_torch(
                    output[k, :, :, :], 1.0)
                inverted_binary_mask = Editor.invert_mask(
                    binary_mask)  # Now a numpy array instead of torch tensor

                anonymized_image = Anonymizer.anonymize_image(
                    image, inverted_binary_mask)

                cv2.imwrite(self.write_path + "/orig_" + str(count) + ".jpg",
                            image)
                cv2.imwrite(self.write_path + "/anon_" + str(count) + ".jpg",
                            anonymized_image)
                count += 1

            del target, output

    def set_cuda(self):
        """Move the model to the GPU when CUDA is available."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def set_weights(self):
        """Load ``weights/<state_dict>`` into the model and switch to eval mode.

        Without CUDA, storages are remapped to the CPU so that a GPU-saved
        checkpoint can still be loaded.
        """
        if torch.cuda.is_available():
            buffered_state_dict = torch.load("weights/" + self.state_dict)
        else:
            buffered_state_dict = torch.load(
                "weights/" + self.state_dict,
                map_location=lambda storage, loc: storage)

        self.model.load_state_dict(buffered_state_dict)
        self.model.eval()
return img #-----------------------------------------------------------------------------------------------------------# #-----------------------------------------Model loading-----------------------------------------------------# # Load segmentation net model print("\nLoading Segmentation Network") segmentation_model = UNet(backbone="resnet18", num_classes=2) if torch.cuda.is_available(): trained_dict = torch.load(SEGMENTATION_NET_CHECKPOINT)['state_dict'] else: trained_dict = torch.load(SEGMENTATION_NET_CHECKPOINT, map_location="cpu")['state_dict'] segmentation_model.load_state_dict(trained_dict, strict=False) segmentation_model.to(device) segmentation_model.eval() # Create segmentation object segmentObj = VideoInference( model=segmentation_model, video_path=0, input_size=320, height=OUT_HEIGHT, width=OUT_WIDTH, use_cuda=torch.cuda.is_available(), draw_mode='matting', ) print("Done Loading Segmentation Network\n") # Load style transfer net model print("Loading Style Transfer Network")