def testwithdehaze():
    """Evaluate the dehaze + segmentation pipeline on the hazy validation set.

    Loads AOD-Net dehazing weights and the segmentation model selected by
    ``args.choose_net`` (both on CPU), runs every validation sample through
    dehaze -> segmentation, and prints per-image and average pixel
    accuracy / mIoU.

    Raises:
        ValueError: if ``args.choose_net`` is not a supported network.
    """
    model_dehaze = aodnet.AODnet()
    model_dehaze.load_state_dict(
        torch.load("./weight/dehaze.pth", map_location='cpu'))
    if args.choose_net == "Unet":
        model_segmentation = my_unet.UNet(3, 1)
    elif args.choose_net == "Enet":
        model_segmentation = enet.ENet(num_classes=1)
    elif args.choose_net == "Segnet":
        model_segmentation = segnet.SegNet(3, 1)
    else:
        # Fail fast; the original fell through to an UnboundLocalError below.
        raise ValueError(
            "Unsupported choose_net for dehaze testing: {}".format(
                args.choose_net))
    model_segmentation.load_state_dict(
        torch.load(args.weight, map_location='cpu'))
    liver_dataset = LiverDataset_three("data/val_dehaze",
                                       transform=x_transform,
                                       target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model_dehaze.eval()
    model_segmentation.eval()
    metric = SegmentationMetric(2)  # binary task: background / foreground
    mean_acc, mean_miou = [], []
    with torch.no_grad():
        for src, _rain, mask in dataloaders:
            y = model_dehaze(src)
            y1 = model_segmentation(y)
            y1 = torch.squeeze(y1).numpy()
            y_label = torch.squeeze(mask).numpy()
            y_label = y_label * 255
            y1 = y1 * 127.5
            # All supported nets used the identical 0.5 threshold; the
            # original duplicated this line in a four-way if/elif chain.
            # NOTE(review): y1 was already scaled by 127.5, so thresholding
            # at 0.5 keeps nearly every positive activation — confirm this
            # matches the intended binarization.
            img_y = (y1 > 0.5)
            img_y = img_y.astype(int)
            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.pixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            mean_acc.append(acc)
            mean_miou.append(mIoU)
            print(acc, mIoU)
    print("average acc:%0.6f average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
def test_save_loss_graphs_no_class_weight(self):
    """Build SegNetLoss at every train depth and dump its computation graph."""
    inputs = Variable(
        np.random.uniform(-1, 1, self.x_shape).astype(np.float32))
    label_shape = (self.x_shape[0], self.x_shape[2], self.x_shape[3])
    labels = Variable(
        np.random.randint(0, 12, label_shape).astype(np.int32))
    for depth in six.moves.range(1, self.n_encdec + 1):
        predictor = segnet.SegNet(n_encdec=self.n_encdec, n_classes=12,
                                  in_channel=self.x_shape[1])
        model = segnet.SegNetLoss(
            predictor, class_weight=None, train_depth=depth)
        loss = model(inputs, labels)
        graph = build_computational_graph(
            [loss], variable_style=_var_style,
            function_style=_func_style).dump()
        # Every encoder/decoder pair must be registered on the predictor.
        for e in range(1, self.n_encdec + 1):
            self.assertTrue(
                'encdec{}'.format(e) in model.predictor._children)
        fn = 'tests/SegNet_xt_depth-{}_{}.dot'.format(self.n_encdec, depth)
        if not os.path.exists(fn):
            with open(fn, 'w') as f:
                f.write(graph)
            subprocess.call('dot -Tpng {} -o {}'.format(
                fn, fn.replace('.dot', '.png')), shell=True)
def main():
    """Train unet/segnet on the binary-seg or human-parse dataset.

    Derives the train/val file lists and the output weight file from
    ``args.nClasses`` (2 -> 'seg', otherwise 'parse'), trains with a
    val_acc checkpoint, then writes loss/accuracy plots and history
    under ./results/.

    Raises:
        ValueError: if ``args.model`` is neither 'unet' nor 'segnet'.
    """
    nClasses = args.nClasses
    train_batch_size = 16
    val_batch_size = 16
    epochs = 50
    img_height = 256
    img_width = 256
    root_path = '../../datasets/segmentation/'
    mode = 'seg' if nClasses == 2 else 'parse'
    train_file = './data/{}_train.txt'.format(mode)
    val_file = './data/{}_test.txt'.format(mode)
    if args.model == 'unet':
        model = unet.Unet(nClasses, input_height=img_height,
                          input_width=img_width)
    elif args.model == 'segnet':
        model = segnet.SegNet(nClasses, input_height=img_height,
                              input_width=img_width)
    else:
        raise ValueError(
            'Does not support {}, only supports unet and segnet now'.format(
                args.model))
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(lr=1e-4),
                  metrics=['accuracy'])
    model.summary()
    train = segdata_generator.generator(root_path, train_file,
                                        train_batch_size, nClasses,
                                        img_height, img_width)
    val = segdata_generator.generator(root_path, val_file, val_batch_size,
                                      nClasses, img_height, img_width,
                                      train=False)
    # BUG FIX: the checkpoint writes into ./weights/, which was never
    # created (only ./results/ was), so saving the best model failed.
    # Create both output directories up front; exist_ok avoids the race
    # of the old os.path.exists + os.mkdir pair.
    os.makedirs('./results/', exist_ok=True)
    os.makedirs('./weights/', exist_ok=True)
    save_file = './weights/{}_seg_weights.h5'.format(args.model) if nClasses == 2 \
        else './weights/{}_parse_weights.h5'.format(args.model)
    checkpoint = ModelCheckpoint(save_file, monitor='val_acc',
                                 save_best_only=True, save_weights_only=True,
                                 verbose=1)
    history = model.fit_generator(
        train,
        steps_per_epoch=12706 // train_batch_size,
        validation_data=val,
        validation_steps=5000 // val_batch_size,
        epochs=epochs,
        callbacks=[checkpoint],
    )
    plot_history(history, './results/', args.model)
    save_history(history, './results/', args.model)
def trainwithdehaze():
    """Jointly train the AOD-Net dehazing model with a segmentation model.

    Builds AOD-Net plus the segmentation net selected by
    ``args.choose_net``, pairs an MSE loss / Adam optimizer for dehazing
    with a BCE loss / Adam optimizer for segmentation, and hands both to
    ``train_dehaze_model`` for 6 epochs over the hazy training set.
    """
    model_dehaze = aodnet.AODnet().to(device)
    # NOTE(review): looks like (channels, batch, H, W); a profiling input
    # would normally be (1, 3, 256, 256).  Only the removed FLOPs-profiling
    # code used this, so it is currently harmless — confirm before reuse.
    dsize = (3, 1, 256, 256)
    # Segmentation backbone selected by command-line flag.
    if args.choose_net == "Unet":
        model_segmentation = my_unet.UNet(3, 1).to(device)
    elif args.choose_net == "Enet":
        model_segmentation = enet.ENet(num_classes=1).to(device)
    elif args.choose_net == "Segnet":
        model_segmentation = segnet.SegNet(3, 1).to(device)
    batch_size = args.batch_size
    # Loss for the dehazing branch (pixel-wise regression).
    criterion_dehaze = torch.nn.MSELoss()
    # Optimizer for the dehazing branch.
    optimizer_dehaze = optim.Adam(model_dehaze.parameters(
    ))  # model.parameters(): returns an iterator over module parameters
    # Loss for the segmentation branch (binary masks).
    criterion_segmentation = torch.nn.BCELoss()
    # Optimizer for the segmentation branch.
    optimizer_segmentation = optim.Adam(model_segmentation.parameters(
    ))  # model.parameters(): returns an iterator over module parameters
    # Load the hazy-image training dataset.
    dataset_dehaze = LiverDataset_three("data/train_dehaze/",
                                        transform=x_transform,
                                        target_transform=y_transform)
    dataloader_dehaze = DataLoader(dataset_dehaze,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)
    # DataLoader wraps the dataset and yields shuffled minibatches each
    # epoch; num_workers=4 loads batches in parallel worker processes.
    train_dehaze_model(model_dehaze, model_segmentation, criterion_dehaze,
                       criterion_segmentation, optimizer_dehaze,
                       optimizer_segmentation, dataloader_dehaze,
                       num_epochs=6)
def load_model(model_name, noc):
    """Instantiate the segmentation network named ``model_name``.

    Args:
        model_name: one of 'fcn', 'segnet', 'pspnet', 'unet', 'segfast',
            'segfast_basic', 'segfast_mobile', 'segfast_v2_3',
            'segfast_v2_5'.
        noc: number of output classes.

    Returns:
        The constructed (untrained) model.

    Raises:
        ValueError: for an unknown ``model_name`` (the original code fell
        through every ``if`` and raised NameError at ``return``).
    """
    # Lazy factories so only the requested model is constructed.
    factories = {
        'fcn': lambda: fcn.FCN8s(noc),
        'segnet': lambda: segnet.SegNet(3, noc),
        'pspnet': lambda: pspnet.PSPNet(noc),
        'unet': lambda: unet.UNet(noc),
        'segfast': lambda: segfast.SegFast(64, noc),
        'segfast_basic': lambda: segfast_basic.SegFast_Basic(64, noc),
        'segfast_mobile': lambda: segfast_mobile.SegFast_Mobile(noc),
        'segfast_v2_3': lambda: segfast_v2.SegFast_V2(64, noc, 3),
        'segfast_v2_5': lambda: segfast_v2.SegFast_V2(64, noc, 5),
    }
    if model_name not in factories:
        raise ValueError('Unknown model name: {!r}'.format(model_name))
    return factories[model_name]()
def predict_segmentation():
    """Run the trained unet/segnet over the test list and save colorized masks.

    Each prediction is argmax-decoded to a per-pixel class map, mapped to a
    black/white palette, shown on screen and written to ./output/<i>.jpg.
    Like the original, this iterates the generator until it is exhausted
    (the generator controls termination).

    Raises:
        ValueError: if ``args.model`` is neither 'unet' nor 'segnet'.
    """
    n_classes = 2
    images_path = '/home/deep/datasets/'
    val_file = './data/seg_test.txt'
    input_height = 256
    input_width = 256
    if args.model == 'unet':
        m = unet.Unet(n_classes, input_height=input_height,
                      input_width=input_width)
    elif args.model == 'segnet':
        m = segnet.SegNet(n_classes, input_height=input_height,
                          input_width=input_width)
    else:
        raise ValueError('Do not support {}'.format(args.model))
    m.load_weights("./results/{}_weights.h5".format(args.model))
    m.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
    # Class palette: background -> black, foreground -> white.
    colors = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
    i = 0
    for x, y in generator(images_path, val_file, 1, n_classes, input_height,
                          input_width):
        pr = m.predict(x)[0]
        pr = pr.reshape((input_height, input_width,
                         n_classes)).argmax(axis=2)
        # BUG FIX: the mask used to be accumulated per class/channel into a
        # float64 array, which cv2.imshow renders incorrectly (floats are
        # interpreted as [0, 1]).  Indexing the uint8 palette with the
        # class map yields the same pixel values in one step, correctly
        # typed for both imshow and imwrite.
        seg_img = colors[pr]
        cv2.imshow('test', seg_img)
        cv2.imwrite('./output/{}.jpg'.format(i), seg_img)
        i += 1
        cv2.waitKey(30)
def test_save_normal_graphs(self):
    """Run a bare SegNet at each depth and dump its computation graph."""
    inputs = Variable(
        np.random.uniform(-1, 1, self.x_shape).astype(np.float32))
    for depth in six.moves.range(1, self.n_encdec + 1):
        model = segnet.SegNet(n_encdec=self.n_encdec,
                              in_channel=self.x_shape[1])
        out = model(inputs, depth)
        graph = build_computational_graph(
            [out], variable_style=_var_style,
            function_style=_func_style).dump()
        # Every encoder/decoder stage must be a registered child link.
        for e in range(1, self.n_encdec + 1):
            self.assertTrue('encdec{}'.format(e) in model._children)
        fn = 'tests/SegNet_x_depth-{}_{}.dot'.format(self.n_encdec, depth)
        if not os.path.exists(fn):
            with open(fn, 'w') as f:
                f.write(graph)
            subprocess.call('dot -Tpng {} -o {}'.format(
                fn, fn.replace('.dot', '.png')), shell=True)
def test_remove_link(self):
    """After deregistration, only the depth-under-training params remain."""
    opt = optimizers.MomentumSGD(lr=0.01)
    # Exercise every training depth in turn.
    for depth in six.moves.range(1, self.n_encdec + 1):
        predictor = segnet.SegNet(self.n_encdec, self.n_classes,
                                  self.x_shape[1], self.n_mid)
        model = segnet.SegNetLoss(predictor, class_weight=None,
                                  train_depth=depth)
        opt.setup(model)
        # Deregister every link this depth does not train.
        if depth > 1:
            model.predictor.remove_link('conv_cls')
        for d in range(1, self.n_encdec + 1):
            if d == depth:
                continue
            model.predictor.remove_link('encdec{}'.format(d))
        target = 'encdec{}'.format(depth)
        for name, _param in model.namedparams():
            if depth > 1:
                self.assertTrue(target in name)
            else:
                # Depth 1 also keeps the classification convolution.
                self.assertTrue(target in name or 'conv_cls' in name)
def test():
    """Evaluate the selected segmentation model on the CamVid validation set.

    Builds the network named by ``args.choose_net``, loads weights from
    ``args.weight``, prints parameter count / FLOPs, then measures
    per-image inference time while accumulating pixel accuracy and mIoU
    via ``SegmentationMetric``, displaying each prediction next to its
    label with matplotlib.
    """
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    # NOTE(review): this `if` (not `elif`) breaks the chain; harmless since
    # the string comparisons are mutually exclusive, but likely unintended.
    if args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 1).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)
    # Profile parameter count (M) and FLOPs (G) on a dummy 128x192 input.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))
    model.load_state_dict(torch.load(args.weight))
    liver_dataset = LiverDataset("data/val_camvid",
                                 transform=x_transform,
                                 target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()
    metric = SegmentationMetric(13)  # 13 CamVid classes
    multiclass = 1  # 1: argmax decode over classes; else: raw squeeze
    mean_acc, mean_miou = [], []
    alltime = 0.0  # accumulated pure-inference time (s)
    with torch.no_grad():
        for x, y_label in dataloaders:
            x = x.to(device)
            start = time.time()
            y = model(x)
            usingtime = time.time() - start
            alltime = alltime + usingtime
            if multiclass == 1:
                # Sigmoid then per-pixel argmax over the class dimension.
                y = F.sigmoid(y)
                y = y.cpu()
                y = torch.argmax(y.squeeze(0), dim=0).data.numpy()
                print(y.max(), y.min())
                y_label = torch.squeeze(y_label).numpy()
            else:
                y = y.cpu()
                y = torch.squeeze(y).numpy()
                y_label = torch.squeeze(y_label).numpy()
            # Per-network binarization threshold (only CascadNet differs).
            # NOTE(review): "espnet" has no branch here, so its output is
            # never thresholded — confirm whether that is intentional.
            if args.choose_net == "Unet":
                y = (y > 0.5)
            elif args.choose_net == "My_Unet":
                y = (y > 0.5)
            elif args.choose_net == "Enet":
                y = (y > 0.5)
            elif args.choose_net == "Segnet":
                y = (y > 0.5)
            elif args.choose_net == "Scnn":
                y = (y > 0.5)
            elif args.choose_net == "CascadNet":
                y = (y > 0.8)
            elif args.choose_net == "my_drsnet_A":
                y = (y > 0.5)
            elif args.choose_net == "my_drsnet_B":
                y = (y > 0.5)
            elif args.choose_net == "my_drsnet_C":
                y = (y > 0.5)
            elif args.choose_net == "my_drsnet_A_direct_skip":
                y = (y > 0.5)
            elif args.choose_net == "SEResNet":
                y = (y > 0.5)
            elif args.choose_net == "resnext_unet":
                y = (y > 0.5)
            elif args.choose_net == "resnet50_unet":
                y = (y > 0.5)
            elif args.choose_net == "unet_res34":
                y = (y > 0.5)
            elif args.choose_net == "dfanet":
                y = (y > 0.5)
            elif args.choose_net == "cgnet":
                y = (y > 0.5)
            elif args.choose_net == "lednet":
                y = (y > 0.5)
            elif args.choose_net == "bisenet":
                y = (y > 0.5)
            elif args.choose_net == "pspnet":
                y = (y > 0.5)
            elif args.choose_net == "fddwnet":
                y = (y > 0.5)
            elif args.choose_net == "contextnet":
                y = (y > 0.5)
            elif args.choose_net == "linknet":
                y = (y > 0.5)
            elif args.choose_net == "edanet":
                y = (y > 0.5)
            elif args.choose_net == "erfnet":
                y = (y > 0.5)
            img_y = y.astype(int).squeeze()
            print(y_label.shape, img_y.shape)
            # Stack prediction over label for side-by-side visualization.
            image = np.concatenate((img_y, y_label))
            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.classPixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            mean_acc.append(acc[1])
            mean_miou.append(mIoU)
            print(acc, mIoU)
            plt.imshow(image * 5)  # *5 to spread class ids over gray levels
            plt.pause(0.1)
            plt.show()
    # When benchmarking speed, disable the acc/miou computation above so it
    # does not pollute the timing.  638 = number of validation images
    # (presumably; verify against the dataset size).
    print("Took ", alltime, "seconds")
    print("Took", alltime / 638.0, "s/perimage")
    print("FPS", 1 / (alltime / 638.0))
    print("average acc:%0.6f average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
def test_img(src_path, label_path):
    """Run every trained network on one image and save all predicted masks.

    Loads each model's weights, pushes the resized/normalized source image
    through it, thresholds the output at 0.5 and writes one binary mask per
    model plus the source and scaled label into ./data/result/.

    Args:
        src_path: path of the RGB input image.
        label_path: path of its ground-truth binary mask.

    Returns:
        0 on completion (kept for backward compatibility with callers).
    """
    ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
    # (weight file, output file name, model) for every evaluated network.
    # BUG FIX: the original constructed MultiscaleSENetB via the typo
    # 'my_drsnetmy_drsnet.MultiscaleSENetB', a guaranteed NameError.
    configs = [
        ("./weight/enet_weight.pth", "enet_predict.png",
         enet.ENet(num_classes=1)),
        ("./weight/segnet_weight.pth", "segnet_predict.png",
         segnet.SegNet(3, 1)),
        ("./weight/my_drsnet_A_weight.pth", "my_drsnet_A_predict.png",
         my_drsnet.MultiscaleSENetA(3, 1)),
        ("./weight/my_drsnet_B_weight.pth", "my_drsnet_B_predict.png",
         my_drsnet.MultiscaleSENetB(3, 1)),
        ("./weight/my_drsnet_C_weight.pth", "my_drsnet_C_predict.png",
         my_drsnet.MultiscaleSENetC(3, 1)),
        ("./weight/my_drsnet_A_direct_skip_weight.pth",
         "my_drsnet_A_direct_skip_predict.png",
         my_drsnet.MultiscaleSENetA_direct_skip(3, 1)),
        ("./weight/SEResNet18_weight.pth", "y_SEResNet18_predict.png",
         my_drsnet.SEResNet18(in_ch=3, out_ch=1)),
        ("./weight/dfanet.pth", "dfanet_predict.png",
         dfanet.DFANet(ch_cfg, 3, 1)),
        ("./weight/cgnet.pth", "cgnet_predict.png",
         cgnet.Context_Guided_Network(1)),
        ("./weight/lednet.pth", "lednet_predict.png",
         lednet.Net(num_classes=1)),
        ("./weight/bisenet.pth", "bisenet_predict.png",
         bisenet.BiSeNet(1, 'resnet18')),
        ("./weight/fddwnet.pth", "fddwnet_predict.png",
         fddwnet.Net(classes=1)),
        ("./weight/contextnet.pth", "contextnet_predict.png",
         contextnet.ContextNet(classes=1)),
        ("./weight/linknet.pth", "linknet_predict.png",
         linknet.LinkNet(classes=1)),
        ("./weight/edanet.pth", "edanet_predict.png",
         edanet.EDANet(classes=1)),
        ("./weight/erfnet.pth", "erfnet_predict.png",
         erfnet.ERFNet(classes=1)),
    ]

    # Preprocess the input once: resize to 128x192, normalize, add batch dim.
    src = Image.open(src_path)
    src = src.resize((128, 192))
    src = x_transform(src)
    src = src.to(device)
    src = torch.unsqueeze(src, 0)

    # Save the resized source and the 0/255-scaled label for comparison.
    src1 = Image.open(src_path).resize((128, 192))
    src1.save("./data/result/" + "_src.png")
    label = Image.open(label_path).resize((128, 192))
    label = np.array(label) * 255
    cv2.imwrite("./data/result/" + "_label.png", label)

    # One load/predict/threshold/save pass per network (this replaces 16
    # copy-pasted blocks from the original).
    for weight_path, out_name, model in configs:
        model = model.to(device)
        model.load_state_dict(torch.load(weight_path))
        model.eval()
        pred = model(src).cpu().detach().numpy().reshape(192, 128)
        mask = (pred > 0.5).astype(int) * 255
        cv2.imwrite("./data/result/" + out_name, mask)
    return 0
def train():
    """Train the segmentation network selected by ``args.choose_net`` on CamVid.

    Builds the model, optionally loads pretrained weights, profiles its
    parameter count and FLOPs, then trains with SGD plus a cosine-annealed
    learning rate, using either a class-weighted CrossEntropyLoss
    (multiclass path) or BCELoss (binary path).
    """
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    # NOTE(review): this `if` (not `elif`) breaks the chain; harmless since
    # the string comparisons are mutually exclusive, but likely unintended.
    if args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 13).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_nest":
        model = unet_nest.UNet_Nested(3, 2).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(3, 1).to(device)
    elif args.choose_net == "trangle_net":
        model = mytrangle_net.trangle_net(3, 1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)
    from collections import OrderedDict
    # 0: train from scratch
    # 1: load pretrained weights into the chosen network
    # 2: load name-matching pretrained weights into a fresh network
    loadpretrained = 0
    if loadpretrained == 1:
        model.load_state_dict(torch.load(args.weight))
    elif loadpretrained == 2:
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
        model_dict = model.state_dict()
        pretrained_dict = torch.load(args.weight)
        # Keep only the checkpoint entries whose names exist in this model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    # Profile parameter count (M) and FLOPs (G) on a dummy 128x192 input.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))
    batch_size = args.batch_size
    # Load the CamVid training set.
    liver_dataset = LiverDataset("data/train_camvid/",
                                 transform=x_transform,
                                 target_transform=y_transform)
    len_img = liver_dataset.__len__()
    dataloader = DataLoader(liver_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=24)
    # DataLoader yields shuffled minibatches each epoch; num_workers=24
    # loads batches in parallel worker processes.
    # All parameters are being optimized (no layer freezing here).
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    # Cosine annealing over ~10 epochs worth of iterations per cycle.
    cosine_lr_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10 * int(len_img / batch_size), eta_min=0.00001)
    multiclass = 1  # 1: 13-class CamVid training; else: binary masks
    if multiclass == 1:
        # Per-class weights for the 13 CamVid classes (class 0 ignored
        # via zero weight); precomputed, see also weighing(dataloader, ...).
        class_weights = np.array([
            0., 6.3005947, 4.31063664, 34.09234699, 50.49834979, 3.88280945,
            50.49834979, 8.91626081, 47.58477105, 29.41289083, 18.95706775,
            37.84558871, 39.3477858
        ])  # camvid
        class_weights = torch.from_numpy(class_weights).float().to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        train_modelmulticlasses(model, criterion, optimizer, dataloader,
                                cosine_lr_scheduler)
    else:
        # Binary segmentation loss.
        criterion = torch.nn.BCELoss()
        train_model(model, criterion, optimizer, dataloader,
                    cosine_lr_scheduler)
n_classes = args.nClasses images_path = '../../datasets/segmentation/' val_file = './data/seg_test.txt' if n_classes == 2 else './data/parse_test.txt' weights_file = './weights/{}_seg_weights.h5'.format(args.model) if n_classes == 2 \ else './weights/{}_parse_weights.h5'.format(args.model) input_height = 256 input_width = 256 if args.model == 'unet': m = unet.Unet(n_classes, input_height=input_height, input_width=input_width) elif args.model == 'segnet': m = segnet.SegNet(n_classes, input_height=input_height, input_width=input_width) else: raise ValueError('Do not support {}'.format(args.model)) m.load_weights(weights_file.format(args.model)) m.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) print('Start evaluating..') pbdr = tqdm(total=5000) iou = [0. for _ in range(1, n_classes)] count = [0. for _ in range(1, n_classes)] for x, y in generator(images_path, val_file,
def __init__(self, args):
    """Build the wrapped SegNet from parsed command-line arguments.

    Args:
        args: namespace providing ``in_ch``, ``out_ch`` and
            ``base_kernel`` — the only fields consumed here.
    """
    self.model = segnet.SegNet(args.in_ch,
                               args.out_ch,
                               args.base_kernel)
def test_backward(self):
    """One SGD step must update only the links registered for the depth.

    For every train depth: build SegNetLoss, deregister all non-target
    links, snapshot the initial weights, run one amplified backward/update
    step, then assert that exactly the target encdec (plus conv_cls when
    depth == 1) changed.  Also dumps the backward computation graph.
    """
    opt = optimizers.MomentumSGD(lr=0.01)
    # Update each depth
    for depth in six.moves.range(1, self.n_encdec + 1):
        model = segnet.SegNet(self.n_encdec, self.n_classes,
                              self.x_shape[1], self.n_mid)
        model = segnet.SegNetLoss(model, class_weight=None,
                                  train_depth=depth)
        opt.setup(model)
        # Deregister non-target links from opt
        if depth > 1:
            model.predictor.remove_link('conv_cls')
        for d in range(1, self.n_encdec + 1):
            if d != depth:
                model.predictor.remove_link('encdec{}'.format(d))
        # Keep the initial values (deep copies, verified distinct below)
        prev_params = {
            'conv_cls': copy.deepcopy(model.predictor.conv_cls.W.data)
        }
        for d in range(1, self.n_encdec + 1):
            name = '/encdec{}/enc/W'.format(d)
            encdec = getattr(model.predictor, 'encdec{}'.format(d))
            prev_params[name] = copy.deepcopy(encdec.enc.W.data)
            # Guard against a shallow copy silently making the test pass.
            self.assertTrue(prev_params[name] is not encdec.enc.W.data)
        # Update the params; the loss is amplified so even tiny gradients
        # produce a measurable weight change after one step.
        x, t = self.get_xt()
        loss = model(x, t)
        loss.data *= 1E20
        model.cleargrads()
        loss.backward()
        opt.update()
        for d in range(1, self.n_encdec + 1):
            # Only the weight in the target layer should have been updated:
            # expect inequality (assertFalse on array_equal) at d == depth,
            # equality everywhere else.
            c = self.assertFalse if d == depth else self.assertTrue
            encdec = getattr(opt.target.predictor, 'encdec{}'.format(d))
            self.assertTrue(hasattr(encdec, 'enc'))
            self.assertTrue(hasattr(encdec.enc, 'W'))
            self.assertTrue('/encdec{}/enc/W'.format(d) in prev_params)
            c(np.array_equal(encdec.enc.W.data,
                             prev_params['/encdec{}/enc/W'.format(d)]),
              msg='depth:{} d:{} diff:{}'.format(
                  depth, d,
                  np.sum(encdec.enc.W.data -
                         prev_params['/encdec{}/enc/W'.format(d)])))
        if depth == 1:
            # The weight in the last (classification) layer should be updated
            self.assertFalse(
                np.allclose(model.predictor.conv_cls.W.data,
                            prev_params['conv_cls']))
        # Dump the backward graph once per (n_encdec, depth) combination.
        cg = build_computational_graph([loss],
                                       variable_style=_var_style,
                                       function_style=_func_style).dump()
        fn = 'tests/SegNet_bw_depth-{}_{}.dot'.format(self.n_encdec, depth)
        # NOTE: when the .dot already exists this skips the remaining
        # namedparams check below as well, not just the file write.
        if os.path.exists(fn):
            continue
        with open(fn, 'w') as f:
            f.write(cg)
        subprocess.call('dot -Tpng {} -o {}'.format(
            fn, fn.replace('.dot', '.png')),
            shell=True)
        # After deregistration, every remaining encdec param must belong
        # to the depth under training.
        for name, param in model.namedparams():
            encdec_depth = re.search('encdec([0-9]+)', name)
            if encdec_depth:
                ed = int(encdec_depth.groups()[0])
                self.assertEqual(ed, depth)