def validate(args):
    """Run a segmentation model over the CamVid 'val' split and print metrics.

    NOTE(review): a second ``validate`` is defined further down in this file
    and rebinds the name, so this version appears to be shadowed at import
    time -- confirm whether it is still needed.

    Args:
        args: parsed options. The attributes read here are ``vis``,
            ``dataset_path``, ``validate_model``, ``structure``,
            ``validate_model_state_dict`` and ``blend``.

    Raises:
        ValueError: if ``args.validate_model`` is empty and ``args.structure``
            is not one of the supported architecture names.
    """
    if args.vis:
        # The visdom image calls are currently disabled, so this handle is
        # unused; it is still created to keep the connection side effect.
        vis = visdom.Visdom()

    # An empty dataset path falls back to ~/Data/CamVid.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, split='val')
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        # NOTE(review): torch.load unpickles arbitrary objects -- only point
        # this at checkpoint files you trust.
        model = torch.load(args.validate_model)
    else:
        # Lambdas keep construction lazy: only the selected architecture's
        # constructor name is looked up and called.
        builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=dst.n_classes),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=dst.n_classes),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=dst.n_classes),
            'ResNetDUC': lambda: ResNetDUC(n_classes=dst.n_classes),
            'segnet': lambda: segnet(n_classes=dst.n_classes),
            'ENet': lambda: ENet(n_classes=dst.n_classes),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22', n_classes=dst.n_classes),
            'pspnet': lambda: pspnet(n_classes=dst.n_classes, use_aux=False),
            'erfnet': lambda: erfnet(n_classes=dst.n_classes),
        }
        if args.structure not in builders:
            # Previously an unknown name left `model` unbound and failed later
            # with an unrelated-looking UnboundLocalError.
            raise ValueError('unsupported structure: {!r}'.format(args.structure))
        model = builders[args.structure]()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(torch.load(args.validate_model_state_dict))
            except KeyError:
                # Best effort: a mismatching state dict is reported, not fatal.
                print('missing key')
    model.eval()

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the loader output in PyTorch Variables.
        imgs = Variable(imgs)
        labels = Variable(labels)

        outputs = model(imgs)
        # Take the max along axis=1 (original author's note: outputs has shape
        # batch_size * n_classes * height * width). max() returns the maximum
        # values and their indices; the indices are used as predicted labels.
        pred = outputs.data.max(1)[1].numpy()
        gt = labels.data.numpy()

        if args.vis and i % 50 == 0:
            img = imgs.data.numpy()[0]
            # label_color only fed the (disabled) visdom display; the call is
            # kept so decode_segmap is still invoked exactly as before.
            label_color = dst.decode_segmap(gt[0]).transpose(2, 0, 1)
            pred_label_color = dst.decode_segmap(pred[0]).transpose(2, 0, 1)

            if args.blend:
                # Undo the input scaling/offset to get a displayable image.
                # NOTE(review): the constants are presumably the per-channel
                # mean removed by the loader -- confirm against camvidLoader.
                img_hwc = img.transpose(1, 2, 0)
                img_hwc = img_hwc*255.0
                img_hwc += np.array([104.00699, 116.66877, 122.67892])
                img_hwc = np.array(img_hwc, dtype=np.uint8)
                pred_label_color_hwc = pred_label_color.transpose(1, 2, 0)
                pred_label_color_hwc = np.array(pred_label_color_hwc, dtype=np.uint8)
                # 50/50 overlay of the input image and the coloured prediction.
                label_blend = img_hwc * 0.5 + pred_label_color_hwc * 0.5
                label_blend = np.array(label_blend, dtype=np.uint8)
                misc.imsave('/tmp/label_blend.png', label_blend)

        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
def validate(args):
    """Run a segmentation model over a validation dataset and print metrics.

    The per-sample visdom / blended-image visualisation is disabled in this
    function; only the metric computation is active.

    Args:
        args: parsed options. The attributes read here are ``vis``,
            ``dataset_path``, ``dataset``, ``dataset_type``,
            ``validate_model``, ``structure``, ``init_vgg16``,
            ``validate_model_state_dict`` and ``cuda``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'CamVid'`` nor
            ``'CityScapes'``, or if ``args.validate_model`` is empty and
            ``args.structure`` is not a supported architecture name.
    """
    if args.vis:
        # Unused while the visualisation is disabled; still created so the
        # visdom connection side effect is unchanged.
        vis = visdom.Visdom()

    # Expand '~' only in a user-supplied path. Unconditionally expanding
    # args.dataset_path would replace the default below with '' whenever no
    # path was given.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = os.path.expanduser(args.dataset_path)

    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        # Previously this fell through and failed later on the unbound `dst`.
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        # NOTE(review): torch.load unpickles arbitrary objects -- only point
        # this at checkpoint files you trust.
        model = torch.load(args.validate_model)
    else:
        # Lambdas keep construction lazy: only the selected architecture's
        # constructor (and args.init_vgg16, where used) is evaluated.
        builders = {
            'fcn32s': lambda: fcn(
                module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn16s': lambda: fcn(
                module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn8s': lambda: fcn(
                module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet18_32s': lambda: fcn_resnet18(
                module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet18_16s': lambda: fcn_resnet18(
                module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet18_8s': lambda: fcn_resnet18(
                module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet34_32s': lambda: fcn_resnet34(
                module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet34_16s': lambda: fcn_resnet34(
                module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_resnet34_8s': lambda: fcn_resnet34(
                module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_MobileNet_32s': lambda: fcn_MobileNet(
                module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_MobileNet_16s': lambda: fcn_MobileNet(
                module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn_MobileNet_8s': lambda: fcn_MobileNet(
                module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'ResNetDUC': lambda: ResNetDUC(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'ResNetDUCHDC': lambda: ResNetDUCHDC(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet': lambda: segnet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet_vgg19': lambda: segnet_vgg19(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet_unet': lambda: segnet_unet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet_alignres': lambda: segnet_alignres(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'sqnet': lambda: sqnet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet_squeeze': lambda: segnet_squeeze(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'ENet': lambda: ENet(n_classes=dst.n_classes),
            'ENetV2': lambda: ENetV2(n_classes=dst.n_classes),
            'drn_d_22': lambda: DRNSeg(
                model_name='drn_d_22', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'drn_a_50': lambda: DRNSeg(
                model_name='drn_a_50', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'drn_a_18': lambda: DRNSeg(
                model_name='drn_a_18', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'drn_e_22': lambda: DRNSeg(
                model_name='drn_e_22', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'erfnet': lambda: erfnet(n_classes=dst.n_classes),
            'fcdensenet103': lambda: fcdensenet103(n_classes=dst.n_classes),
            'fcdensenet56': lambda: fcdensenet56(n_classes=dst.n_classes),
            'fcdensenet_tiny': lambda: fcdensenet_tiny(n_classes=dst.n_classes),
            'Res_Deeplab_101': lambda: Res_Deeplab_101(n_classes=dst.n_classes),
            'Res_Deeplab_50': lambda: Res_Deeplab_50(n_classes=dst.n_classes),
            'EDANet': lambda: EDANet(n_classes=dst.n_classes),
            'drn_a_asymmetric_18': lambda: DRNSeg(
                model_name='drn_a_asymmetric_18', n_classes=dst.n_classes, pretrained=False),
        }
        if args.structure not in builders:
            # Previously an unknown name left `model` unbound and failed later
            # with an unrelated-looking UnboundLocalError.
            raise ValueError('unsupported structure: {!r}'.format(args.structure))
        model = builders[args.structure]()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(
                    torch.load(args.validate_model_state_dict))
            except KeyError:
                # Best effort: a mismatching state dict is reported, not fatal.
                print('missing key')

    if args.cuda:
        model.cuda()
    model.eval()

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the loader output in PyTorch Variables.
        imgs = Variable(imgs)
        labels = Variable(labels)

        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        outputs = model(imgs)
        # Take the max along axis=1 (original author's note: outputs has shape
        # batch_size * n_classes * height * width). max() returns the maximum
        # values and their indices; the indices are used as predicted labels.
        # .cpu() first so .numpy() also works when running on CUDA.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()

        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
def train(args):
    """Train a segmentation model on CamVid with SGD, optionally resuming.

    Each epoch iterates the whole training loader, prints the per-batch loss,
    optionally streams losses and sample predictions to visdom, and optionally
    saves the model's state dict every ``args.save_epoch`` epochs.

    Args:
        args: parsed options. The attributes read here are ``vis``,
            ``dataset_path``, ``data_augment``, ``n_classes``, ``batch_size``,
            ``resume_model``, ``structure``, ``init_vgg16``,
            ``resume_model_state_dict``, ``cuda``, ``lr``, ``save_model`` and
            ``save_epoch``.

    Raises:
        ValueError: if ``args.resume_model`` is empty and ``args.structure``
            is not a supported architecture name, or if a resume path does not
            carry an integer epoch between its last ``'_'`` and last ``'.'``.
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()

    # An empty dataset path falls back to ~/Data/CamVid.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment)
    # Force the class count requested on the command line.
    dst.n_classes = args.n_classes
    trainloader = torch.utils.data.DataLoader(dst, batch_size=args.batch_size, shuffle=True)

    start_epoch = 0
    if args.resume_model != '':
        # NOTE(review): torch.load unpickles arbitrary objects -- only point
        # this at checkpoint files you trust.
        model = torch.load(args.resume_model)
        start_epoch = _epoch_from_checkpoint_name(args.resume_model)
    else:
        # Lambdas keep construction lazy: only the selected architecture's
        # constructor (and args.init_vgg16, where used) is evaluated.
        builders = {
            'fcn32s': lambda: fcn(
                module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn16s': lambda: fcn(
                module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'fcn8s': lambda: fcn(
                module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'ResNetDUC': lambda: ResNetDUC(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'segnet': lambda: segnet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'ENet': lambda: ENet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'drn_d_22': lambda: DRNSeg(
                model_name='drn_d_22', n_classes=dst.n_classes, pretrained=args.init_vgg16),
            'pspnet': lambda: pspnet(
                n_classes=dst.n_classes, pretrained=args.init_vgg16, use_aux=False),
            'erfnet': lambda: erfnet(n_classes=dst.n_classes),
        }
        if args.structure not in builders:
            # Previously an unknown name left `model` unbound and failed later
            # with an unrelated-looking UnboundLocalError.
            raise ValueError('unsupported structure: {!r}'.format(args.structure))
        model = builders[args.structure]()

        if args.resume_model_state_dict != '':
            try:
                # fcn32s / fcn16s / fcn8s differ slightly in their parameters,
                # so loading one variant's weights into another raises
                # KeyError when resuming; caught here as a stop-gap.
                start_epoch = _epoch_from_checkpoint_name(args.resume_model_state_dict)
                pretrained_dict = torch.load(args.resume_model_state_dict)
                model.load_state_dict(pretrained_dict)
            except KeyError:
                print('missing key')

    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, momentum=0.99, weight_decay=5e-4)

    for epoch in range(start_epoch+1, 20000, 1):
        loss_epoch = 0
        loss_avg_epoch = 0
        batch_count = 0
        for i, (imgs, labels) in enumerate(trainloader):
            print(i)
            # Number of batches seen so far. Using the last index `i` as the
            # divisor under-counted by one and was zero for a single batch.
            batch_count = i + 1
            imgs = Variable(imgs)
            labels = Variable(labels)

            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()

            outputs = model(imgs)

            if args.vis and i % 50 == 0:
                # .cpu() before .numpy() so this also works with args.cuda.
                pred_labels = outputs.cpu().data.max(1)[1].numpy()
                label_color = dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win)
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win)

                if epoch < 100:
                    # Dump early-epoch samples to a per-run directory in /tmp.
                    if not os.path.exists('/tmp/'+init_time):
                        os.mkdir('/tmp/'+init_time)
                    time_str = str(int(time.time()))
                    label_hwc = label_color.transpose(1, 2, 0)
                    pred_label_hwc = pred_label_color.transpose(1, 2, 0)
                    # Labels corrected to name the transpose actually applied.
                    print('label_color.transpose(1, 2, 0).shape:', label_hwc.shape)
                    print('pred_label_color.transpose(1, 2, 0).shape:', pred_label_hwc.shape)
                    cv2.imwrite('/tmp/'+init_time+'/'+time_str+'_label.png', label_hwc)
                    cv2.imwrite('/tmp/'+init_time+'/'+time_str+'_pred_label.png', pred_label_hwc)

            # Gradients accumulate across backward passes unless zeroed.
            optimizer.zero_grad()

            loss = cross_entropy2d(outputs, labels)
            loss_numpy = loss.cpu().data.numpy()
            loss_epoch += loss_numpy
            print('loss:', loss_numpy)
            loss.backward()

            optimizer.step()

            # Plot the loss curve within the current epoch. loss_numpy is the
            # CPU copy taken above, so this is safe under CUDA as well.
            if args.vis:
                win = 'loss'
                win_res = vis.line(X=np.ones(1)*i, Y=loss_numpy, win=win, update='append')
                if win_res != win:
                    # Append did not land in the window: create it instead.
                    vis.line(X=np.ones(1)*i, Y=loss_numpy, win=win)

        # Close (clear) the per-epoch loss window.
        if args.vis:
            win = 'loss'
            vis.close(win)

        # Plot the average loss across epochs; skipped when the loader yielded
        # no batches, since there is nothing to average.
        if batch_count:
            loss_avg_epoch = loss_epoch / (batch_count * 1.0)
            if args.vis:
                win = 'loss_epoch'
                win_res = vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch, win=win)

        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(
                model.state_dict(),
                '{}_camvid_class_{}_{}.pt'.format(args.structure, dst.n_classes, epoch))


def _epoch_from_checkpoint_name(path):
    """Return the integer between the last '_' and the last '.' of *path*."""
    return int(path[path.rfind('_') + 1:path.rfind('.')])