def validate(args):
    """Validate a segmentation model on one dataset split.

    Runs the model over every image of ``args.dataset_type``, records the
    per-image cross-entropy loss, optionally saves colorized predictions
    under ``/tmp/<timestamp>/``, prints the loss-rank of the first images,
    and finally prints the overall scores and per-class IoU.

    Reads from ``args``: vis, dataset_path, dataset, dataset_type,
    validate_model, validate_model_state_dict, structure, n_classes,
    init_vgg16, cuda, save_result.
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()
    # Legacy default-path branch: its result is immediately overwritten
    # below, so only args.dataset_path is actually honored.
    if args.dataset_path == '':
        HOME_PATH = os.path.expanduser('~')
        local_path = os.path.join(HOME_PATH, 'Data/CamVid')
    else:
        local_path = args.dataset_path
    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True, split=args.dataset_type)
    else:
        # Unknown dataset: dst stays unbound and the DataLoader below fails.
        pass
    val_loader = torch.utils.data.DataLoader(dst, batch_size=1, shuffle=False)
    if args.validate_model != '':
        # A whole serialized model takes precedence over a state dict.
        model = torch.load(args.validate_model)
    else:
        try:
            # SECURITY: eval() on a CLI-supplied name — acceptable only for a
            # trusted local tool; do not expose to untrusted input.
            model = eval(args.structure)(n_classes=args.n_classes, pretrained=args.init_vgg16)
        except Exception:  # was a bare `except:`; keep Ctrl-C/SystemExit working
            print('missing structure or not support')
            exit(0)
    if args.validate_model_state_dict != '':
        try:
            model.load_state_dict(torch.load(args.validate_model_state_dict, map_location='cpu'))
        except KeyError:
            print('missing key')
    if args.cuda:
        model.cuda()
    model.eval()
    gts, preds, errors, imgs_name = [], [], [], []
    for i, (imgs, labels) in enumerate(val_loader):
        print(i)
        # BUGFIX: a leftover debug `if i == 1: break` stopped after a single
        # image, which also crashed the rank report below (ValueError in
        # errors_indices.index). Removed so the whole split is evaluated.
        img_path = dst.files[args.dataset_type][i]
        img_name = img_path[img_path.rfind('/')+1:]
        imgs_name.append(img_name)
        # Wrap numpy tensors as autograd Variables (legacy PyTorch API).
        imgs = Variable(imgs)
        labels = Variable(labels)
        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()
        outputs = model(imgs)
        loss = cross_entropy2d(outputs, labels)
        loss_np_float = float(loss.cpu().data.numpy())
        errors.append(loss_np_float)
        # argmax over the class axis (outputs: batch*n_classes*H*W)
        # gives the predicted label map.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()
        if args.save_result:
            if not os.path.exists('/tmp/'+init_time):
                os.mkdir('/tmp/'+init_time)
            pred_labels = outputs.cpu().data.max(1)[1].numpy()
            label_color = dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
            pred_label_color = dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
            # decode_segmap returns RGB; OpenCV writes BGR.
            pred_label_color_cv2 = pred_label_color.transpose(1, 2, 0)
            pred_label_color_cv2 = cv2.cvtColor(pred_label_color_cv2, cv2.COLOR_RGB2BGR)
            cv2.imwrite('/tmp/'+init_time+'/{}'.format(img_name), pred_label_color_cv2)
        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)
    # errors_indices[j] = dataset index of the j-th *smallest* loss.
    errors_indices = np.argsort(errors).tolist()
    # NOTE(review): .index(top_i) prints the loss rank of dataset image
    # top_i, not the names of the top-10-loss images; looks like the intent
    # was the latter — confirm before changing the report format.
    # BUGFIX: was a hard `range(10)`, which raised ValueError whenever the
    # split held fewer than 10 images.
    for top_i in range(min(10, len(errors_indices))):
        top_index = errors_indices.index(top_i)
        img_name_top = imgs_name[top_index]
        print('img_name_top:', img_name_top)
    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)
    for i in range(dst.n_classes):
        print(i, class_iou[i])
def train(args):
    """Jointly train the shared drnsegmt_a_18 backbone on semantic
    segmentation (CamVid/CityScapes) and YOLO-style detection.

    Per epoch: one pass over the detection loader, one pass over the
    segmentation loader, optional validation every ``args.val_interval``
    epochs (best-mIoU checkpointing), and visdom curves for losses, lr,
    mIoU and training accuracy when ``args.vis`` is set.
    """
    def type_callback(event):
        # Visdom keypress handler: pressing 's' in the usage text window
        # dumps the 'loss_iteration' curve to /tmp as "x y" lines.
        # Closes over vis / init_time / vis_text_usage /
        # callback_text_usage_window defined below in train().
        if event['event_type'] == 'KeyPress':
            event_key = event['key']
            if event_key == 'Enter':
                pass
            elif event_key == 'Backspace':
                pass
            elif event_key == 'Delete':
                pass
            elif len(event_key) == 1:
                pass
                if event_key=='s':
                    import json
                    win = 'loss_iteration'
                    win_data = vis.get_window_data(win)
                    win_data_dict = json.loads(win_data)
                    win_data_content_dict = win_data_dict['content']
                    win_data_x = np.array(win_data_content_dict['data'][0]['x'])
                    win_data_y = np.array(win_data_content_dict['data'][0]['y'])
                    win_data_save_file = '/tmp/loss_iteration_{}.txt'.format(init_time)
                    # NOTE(review): file is opened in binary mode ('wb') but a
                    # str is written — on Python 3 this raises TypeError when
                    # the handler fires; should presumably be 'w'. Confirm.
                    with open(win_data_save_file, 'wb') as f:
                        for item_x, item_y in zip(win_data_x, win_data_y):
                            f.write("{} {}\n".format(item_x, item_y))
                    done_time = str(int(time.time()))
                    vis.text(vis_text_usage+'done at {}'.format(done_time), win=callback_text_usage_window)

    # Timestamp used to tag files saved by the visdom callback.
    init_time = str(int(time.time()))
    if args.vis:
        # start visdom and close all stale windows
        vis = visdom.Visdom()
        vis.close()
        vis_text_usage = 'Operating in the text window<br>Press s to save data<br>'
        callback_text_usage_window = vis.text(vis_text_usage)
        vis.register_event_handler(type_callback, callback_text_usage_window)

    # ---- datasets and optional class weighting ----
    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    train_dst = None
    val_dst = None
    if args.dataset == 'CamVid':
        train_dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment, split='train')
        val_dst = camvidLoader(local_path, is_transform=True, is_augment=False, split='val')
        # Class-balancing weights are computed from the annotation PNGs.
        trainannot_image_dir = os.path.expanduser(os.path.join(local_path, "trainannot"))
        trainannot_image_files = [os.path.join(trainannot_image_dir, file) for file in os.listdir(trainannot_image_dir) if file.endswith('.png')]
        if args.class_weighting=='MFB':
            # median frequency balancing (SegNet-style)
            class_weight = median_frequency_balancing(trainannot_image_files, num_classes=12)
            class_weight = torch.tensor(class_weight)
        elif args.class_weighting=='ENET':
            class_weight = ENet_weighing(trainannot_image_files, num_classes=12)
            class_weight = torch.tensor(class_weight)
    elif args.dataset == 'CityScapes':
        train_dst = cityscapesLoader(local_path, is_transform=True, split='train')
        val_dst = cityscapesLoader(local_path, is_transform=True, split='val')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)
    if args.cuda:
        if class_weight is not None:
            class_weight = class_weight.cuda()
    print('class_weight:', class_weight)

    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=1, shuffle=True)

    # ---- YOLO detection head configuration ----
    # B boxes per cell, C classes, SxS grid; per-cell tensor is B*5+C wide.
    yolo_B = 2
    yolo_C = 2
    yolo_S = 7
    yolo_out_tensor_shape = yolo_B * 5 + yolo_C
    print('yolo_out_tensor_shape:', yolo_out_tensor_shape)
    det_criterion = yoloLoss(yolo_S, yolo_B, yolo_C, 5, 0.5, args.cuda)
    det_file_root = os.path.expanduser('~/Data/CamVid/train/')
    det_train_dst = yoloDataset(root=det_file_root, list_file=['camvid_det.txt'], train=True, transform=[transforms.ToTensor()], yolo_out_tensor_shape=yolo_out_tensor_shape)
    det_train_loader = torch.utils.data.DataLoader(det_train_dst, batch_size=1, shuffle=True, num_workers=4)

    # ---- model construction / resume ----
    start_epoch = 0
    best_mIoU = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # Epoch number is encoded between the last '_' and the extension.
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1+1:start_epoch_id2])
    else:
        model = drnsegmt_a_18(pretrained=args.init_vgg16, n_classes=args.n_classes, det_tensor_num=yolo_out_tensor_shape)
    if args.resume_model_state_dict != '':
        try:
            # Filename format "..._miou_<f>_class_<n>_<epoch>.pt" carries the
            # best mIoU and the epoch to resume from.
            miou_model_name_str = '_miou_'
            class_model_name_str = '_class_'
            miou_id1 = args.resume_model_state_dict.find(miou_model_name_str)+len(miou_model_name_str)
            miou_id2 = args.resume_model_state_dict.find(class_model_name_str)
            best_mIoU = float(args.resume_model_state_dict[miou_id1:miou_id2])
            start_epoch_id1 = args.resume_model_state_dict.rfind('_')
            start_epoch_id2 = args.resume_model_state_dict.rfind('.')
            start_epoch = int(args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
            pretrained_dict = torch.load(args.resume_model_state_dict, map_location='cpu')
            model.load_state_dict(pretrained_dict)
        except KeyError:
            print('missing resume_model_state_dict or wrong type')
    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    print('best_mIoU:', best_mIoU)

    # ---- optimizer and lr schedule ----
    optimizer = None
    if args.solver == 'SGD':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'RMSprop':
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'Adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=5e-4)
    else:
        print('missing solver or not support')
        exit(0)
    scheduler = None
    if args.lr_policy == 'Constant':
        scheduler = ConstantLR(optimizer)
    elif args.lr_policy == 'Polynomial':
        # base lr=0.01 power=0.9 like PSPNet
        scheduler = PolynomialLR(optimizer, max_iter=args.training_epoch, power=0.9)

    data_count = int(train_dst.__len__() * 1.0 / args.batch_size)
    det_data_count = int(det_train_dst.__len__() * 1.0 / 1)
    print('data_count:', data_count)

    train_gts, train_preds = [], []
    for epoch in range(start_epoch+1, args.training_epoch, 1):
        loss_epoch = 0
        scheduler.step()
        # ----for object detection----
        for det_i, (det_imgs, det_labels, _) in enumerate(det_train_loader):
            model.train()
            det_imgs = Variable(det_imgs)
            det_labels = Variable(det_labels)
            if args.cuda:
                det_imgs = det_imgs.cuda()
                det_labels = det_labels.cuda()
            # Model returns (segmentation logits, detection tensor).
            _, outputs_det = model(det_imgs)
            det_loss = det_criterion(outputs_det, det_labels)
            det_loss = 0.02 * det_loss  # balance detection vs segmentation loss
            det_loss_np = det_loss.cpu().data.numpy()
            optimizer.zero_grad()
            det_loss.backward()
            optimizer.step()
            # Plot per-iteration detection loss; a failed 'append' means the
            # window does not exist yet, so create it.
            if args.vis:
                win = 'det_loss_iteration'
                det_loss_np_expand = np.expand_dims(det_loss_np, axis=0)
                win_res = vis.line(X=np.ones(1)*(det_i+det_data_count*(epoch-1)+1), Y=det_loss_np_expand, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1)*(det_i+det_data_count*(epoch-1)+1), Y=det_loss_np_expand, win=win, opts=dict(title=win, xlabel='iteration', ylabel='loss'))
        # ----for object detection----

        # ----for semantic segment----
        for i, (imgs, labels) in enumerate(train_loader):
            model.train()
            # The final batch can be smaller than batch_size; skip it.
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break
            imgs = Variable(imgs)
            labels = Variable(labels)
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs_sem, _ = model(imgs)
            # Gradients accumulate between backward() calls unless zeroed.
            optimizer.zero_grad()
            loss = cross_entropy2d(outputs_sem, labels, weight=class_weight)
            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np
            loss.backward()
            optimizer.step()
            # ------------------train metrics-------------------------------
            train_pred = outputs_sem.cpu().data.max(1)[1].numpy()
            train_gt = labels.cpu().data.numpy()
            for train_gt_, train_pred_ in zip(train_gt, train_pred):
                train_gts.append(train_gt_)
                train_preds.append(train_pred_)
            # ------------------train metrics-------------------------------
            # Every 50 iterations, show ground truth vs prediction images.
            if args.vis and i%50==0:
                pred_labels = outputs_sem.cpu().data.max(1)[1].numpy()
                label_color = train_dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = train_dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win, opts=dict(title='Gt', caption='Ground Truth'))
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win, opts=dict(title='Pred', caption='Prediction'))
            # Plot per-iteration segmentation loss.
            if args.vis:
                win = 'loss_iteration'
                loss_np_expand = np.expand_dims(loss_np, axis=0)
                win_res = vis.line(X=np.ones(1)*(i+data_count*(epoch-1)+1), Y=loss_np_expand, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1)*(i+data_count*(epoch-1)+1), Y=loss_np_expand, win=win, opts=dict(title=win, xlabel='iteration', ylabel='loss'))
        # ----for semantic segment----

        # Validate on the val split and checkpoint when mIoU improves.
        if args.val_interval > 0 and epoch % args.val_interval == 0:
            print('----starting val----')
            model.eval()
            val_gts, val_preds = [], []
            for val_i, (val_imgs, val_labels) in enumerate(val_loader):
                val_imgs = Variable(val_imgs)
                val_labels = Variable(val_labels)
                if args.cuda:
                    val_imgs = val_imgs.cuda()
                    val_labels = val_labels.cuda()
                val_outputs_sem, _ = model(val_imgs)
                val_pred = val_outputs_sem.cpu().data.max(1)[1].numpy()
                val_gt = val_labels.cpu().data.numpy()
                for val_gt_, val_pred_ in zip(val_gt, val_pred):
                    val_gts.append(val_gt_)
                    val_preds.append(val_pred_)
            score, class_iou = scores(val_gts, val_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                # The scores() dict keys embed formatting; this exact string
                # is the mean-IoU entry.
                if k == 'Mean IoU : \t':
                    v_iou = v
                    if v > best_mIoU:
                        best_mIoU = v_iou
                        torch.save(model.state_dict(), '{}_{}_miou_{}_class_{}_{}.pt'.format(args.structure, args.dataset, best_mIoU, args.n_classes, epoch))
                    # Plot validation mIoU per epoch.
                    if args.vis:
                        win = 'mIoU_epoch'
                        v_iou_expand = np.expand_dims(v_iou, axis=0)
                        win_res = vis.line(X=np.ones(1)*epoch*args.val_interval, Y=v_iou_expand, win=win, update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1)*epoch*args.val_interval, Y=v_iou_expand, win=win, opts=dict(title=win, xlabel='epoch', ylabel='mIoU'))
            print('----ending val----')

        # Plot per-epoch average segmentation loss.
        loss_avg_epoch = loss_epoch / (data_count * 1.0)
        if args.vis:
            win = 'loss_epoch'
            loss_avg_epoch_expand = np.expand_dims(loss_avg_epoch, axis=0)
            win_res = vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch_expand, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch_expand, win=win, opts=dict(title=win, xlabel='epoch', ylabel='loss'))
        # Plot the scheduler's learning rate per epoch.
        if args.vis:
            win = 'lr_epoch'
            lr_epoch = np.array(scheduler.get_lr())
            lr_epoch_expand = np.expand_dims(lr_epoch, axis=0)
            win_res = vis.line(X=np.ones(1)*epoch, Y=lr_epoch_expand, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1)*epoch, Y=lr_epoch_expand, win=win, opts=dict(title=win, xlabel='epoch', ylabel='lr'))
        # ------------------train metrics-------------------------------
        if args.vis:
            score, class_iou = scores(train_gts, train_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Overall Acc : \t':
                    # Plot training overall accuracy per epoch.
                    overall_acc = v
                    if args.vis:
                        win = 'acc_epoch'
                        overall_acc_expand = np.expand_dims(overall_acc, axis=0)
                        win_res = vis.line(X=np.ones(1) * epoch, Y=overall_acc_expand, win=win, update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1) * epoch, Y=overall_acc_expand, win=win, opts=dict(title=win, xlabel='epoch', ylabel='accuracy'))
            # clear for new training metrics
            train_gts, train_preds = [], []
        # ------------------train metrics-------------------------------
        # Periodic (non-best) checkpoint.
        if args.save_model and epoch%args.save_epoch==0:
            torch.save(model.state_dict(), '{}_{}_class_{}_{}.pt'.format(args.structure, args.dataset, args.n_classes, epoch))
def validate(args):
    """Validate a segmentation model chosen by name from a fixed catalog.

    Builds the dataset/loader, instantiates the architecture named by
    ``args.structure`` (or loads a whole serialized model), runs one pass
    over the split, and prints overall scores plus per-class IoU.
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()
    # Legacy default-path branch; overwritten immediately below, so only
    # args.dataset_path is actually honored.
    if args.dataset_path == '':
        HOME_PATH = os.path.expanduser('~')
        local_path = os.path.join(HOME_PATH, 'Data/CamVid')
    else:
        local_path = args.dataset_path
    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        pass  # unknown dataset: dst stays unbound; DataLoader below fails
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)
    if args.validate_model != '':
        # A fully serialized model takes precedence over name-based lookup.
        model = torch.load(args.validate_model)
    else:
        # Architecture catalog: map the CLI name to a constructor call.
        # NOTE(review): there is no final else — an unknown structure name
        # leaves `model` unbound and triggers NameError further down.
        if args.structure == 'fcn32s':
            model = fcn(module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn16s':
            model = fcn(module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn8s':
            model = fcn(module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet18_32s':
            model = fcn_resnet18(module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet18_16s':
            model = fcn_resnet18(module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet18_8s':
            model = fcn_resnet18(module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet34_32s':
            model = fcn_resnet34(module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet34_16s':
            model = fcn_resnet34(module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_resnet34_8s':
            model = fcn_resnet34(module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_MobileNet_32s':
            model = fcn_MobileNet(module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_MobileNet_16s':
            model = fcn_MobileNet(module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn_MobileNet_8s':
            model = fcn_MobileNet(module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'ResNetDUC':
            model = ResNetDUC(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'ResNetDUCHDC':
            model = ResNetDUCHDC(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet':
            model = segnet(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet_vgg19':
            model = segnet_vgg19(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet_unet':
            model = segnet_unet(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet_alignres':
            model = segnet_alignres(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'sqnet':
            model = sqnet(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet_squeeze':
            model = segnet_squeeze(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'ENet':
            model = ENet(n_classes=dst.n_classes)
        elif args.structure == 'ENetV2':
            model = ENetV2(n_classes=dst.n_classes)
        elif args.structure == 'drn_d_22':
            model = DRNSeg(model_name='drn_d_22', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'drn_a_50':
            model = DRNSeg(model_name='drn_a_50', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'drn_a_18':
            model = DRNSeg(model_name='drn_a_18', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'drn_e_22':
            model = DRNSeg(model_name='drn_e_22', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'erfnet':
            model = erfnet(n_classes=dst.n_classes)
        elif args.structure == 'fcdensenet103':
            model = fcdensenet103(n_classes=dst.n_classes)
        elif args.structure == 'fcdensenet56':
            model = fcdensenet56(n_classes=dst.n_classes)
        elif args.structure == 'fcdensenet_tiny':
            model = fcdensenet_tiny(n_classes=dst.n_classes)
        elif args.structure == 'Res_Deeplab_101':
            model = Res_Deeplab_101(n_classes=dst.n_classes)
        elif args.structure == 'Res_Deeplab_50':
            model = Res_Deeplab_50(n_classes=dst.n_classes)
        elif args.structure == 'EDANet':
            model = EDANet(n_classes=dst.n_classes)
        elif args.structure == 'drn_a_asymmetric_18':
            model = DRNSeg(model_name='drn_a_asymmetric_18', n_classes=dst.n_classes, pretrained=False)
    if args.validate_model_state_dict != '':
        try:
            model.load_state_dict(torch.load(args.validate_model_state_dict))
        except KeyError:
            print('missing key')
    if args.cuda:
        model.cuda()
    model.eval()
    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap numpy tensors as autograd Variables (legacy PyTorch API).
        imgs = Variable(imgs)
        labels = Variable(labels)
        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()
        outputs = model(imgs)
        # argmax over the class axis (outputs: batch*n_classes*H*W) gives
        # the predicted label map.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()
        # (a large commented-out visdom/blend visualization block that saved
        #  prediction overlays to /tmp was removed here for readability)
        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)
    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)
    for i in range(dst.n_classes):
        print(i, class_iou[i])
def train(args):
    """Train a segmentation model (architecture named by ``args.structure``)
    on CamVid or CityScapes, with optional visdom loss/image monitoring and
    periodic state-dict checkpoints.
    """
    def type_callback(event):
        # Visdom keypress handler: pressing 's' in the usage text window
        # dumps the 'loss_iteration' curve to /tmp as "x y" lines.
        # Closes over vis / init_time / vis_text_usage /
        # callback_text_usage_window defined below in train().
        if event['event_type'] == 'KeyPress':
            event_key = event['key']
            if event_key == 'Enter':
                pass
            elif event_key == 'Backspace':
                pass
            elif event_key == 'Delete':
                pass
            elif len(event_key) == 1:
                pass
                if event_key=='s':
                    import json
                    win = 'loss_iteration'
                    win_data = vis.get_window_data(win)
                    win_data_dict = json.loads(win_data)
                    win_data_content_dict = win_data_dict['content']
                    win_data_x = np.array(win_data_content_dict['data'][0]['x'])
                    win_data_y = np.array(win_data_content_dict['data'][0]['y'])
                    win_data_save_file = '/tmp/loss_iteration_{}.txt'.format(init_time)
                    # NOTE(review): binary mode ('wb') + str write raises
                    # TypeError on Python 3 when triggered; likely meant 'w'.
                    with open(win_data_save_file, 'wb') as f:
                        for item_x, item_y in zip(win_data_x, win_data_y):
                            f.write("{} {}\n".format(item_x, item_y))
                    done_time = str(int(time.time()))
                    vis.text(vis_text_usage+'done at {}'.format(done_time), win=callback_text_usage_window)

    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()
        vis_text_usage = 'Operating in the text window<br>Press s to save data<br>'
        callback_text_usage_window = vis.text(vis_text_usage)
        vis.register_event_handler(type_callback, callback_text_usage_window)

    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        pass  # unknown dataset: dst stays unbound; DataLoader below fails
    trainloader = torch.utils.data.DataLoader(dst, batch_size=args.batch_size, shuffle=True)

    # ---- model construction / resume ----
    start_epoch = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # Epoch number is encoded between the last '_' and the extension.
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1+1:start_epoch_id2])
    else:
        try:
            # SECURITY NOTE: eval() on a CLI-supplied name; trusted-local use only.
            model = eval(args.structure)(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        except:
            print('missing structure or not support')
            exit(0)
    if args.resume_model_state_dict != '':
        try:
            # The fcn32s/fcn16s/fcn8s variants differ slightly in parameters;
            # cross-loading can raise KeyError, which is caught here.
            start_epoch_id1 = args.resume_model_state_dict.rfind('_')
            start_epoch_id2 = args.resume_model_state_dict.rfind('.')
            start_epoch = int(args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
            pretrained_dict = torch.load(args.resume_model_state_dict)
            model.load_state_dict(pretrained_dict)
        except KeyError:
            print('missing key')
    if args.cuda:
        model.cuda()
    model.train()
    print('start_epoch:', start_epoch)

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    data_count = int(dst.__len__() * 1.0 / args.batch_size)
    print('data_count:', data_count)

    # Effectively trains until interrupted (hard-coded 20000-epoch cap).
    for epoch in range(start_epoch+1, 20000, 1):
        loss_epoch = 0
        loss_avg_epoch = 0
        for i, (imgs, labels) in enumerate(trainloader):
            # The final batch can be smaller than batch_size; skip it.
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break
            print(i)
            imgs = Variable(imgs)
            labels = Variable(labels)
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)
            # Every 50 iterations, show ground truth vs prediction in visdom.
            if args.vis and i%50==0:
                pred_labels = outputs.cpu().data.max(1)[1].numpy()
                label_color = dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win)
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win)
            # Gradients accumulate between backward() calls unless zeroed.
            optimizer.zero_grad()
            loss = cross_entropy2d(outputs, labels)
            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np
            print('loss:', loss_np)
            loss.backward()
            optimizer.step()
            # Plot per-iteration loss; a failed 'append' means the window
            # does not exist yet, so create it.
            if args.vis:
                win = 'loss_iteration'
                loss_np_expand = np.expand_dims(loss_np, axis=0)
                win_res = vis.line(X=np.ones(1)*(i+data_count*(epoch-1)+1), Y=loss_np_expand, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1)*(i+1), Y=loss_np_expand, win=win)
        # Plot per-epoch average loss.
        loss_avg_epoch = loss_epoch / (data_count * 1.0)
        if args.vis:
            win = 'loss_epoch'
            loss_avg_epoch_expand = np.expand_dims(loss_avg_epoch, axis=0)
            win_res = vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch_expand, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch_expand, win=win)
        # Periodic state-dict checkpoint.
        if args.save_model and epoch%args.save_epoch==0:
            torch.save(model.state_dict(), '{}_camvid_class_{}_{}.pt'.format(args.structure, dst.n_classes, epoch))
def validate(args):
    """Visual sanity-check of the detection head of drnsegmt_a_18.

    Loads a trained state dict (required), runs the detection branch over
    the CamVid detection list, decodes the YOLO output tensor into boxes,
    and draws/show each image with its predicted boxes via OpenCV windows
    (blocks on cv2.waitKey per image).
    """
    init_time = str(int(time.time()))
    if args.vis:
        # start visdom and close all stale windows
        vis = visdom.Visdom()
        vis.close()
    # class_weight is unused here beyond the debug print below; kept for
    # symmetry with the training setup.
    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    train_dst = None
    val_dst = None
    if args.dataset == 'CamVid':
        train_dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment, split='train')
        val_dst = camvidLoader(local_path, is_transform=True, is_augment=False, split='val')
    elif args.dataset == 'CityScapes':
        train_dst = cityscapesLoader(local_path, is_transform=True, split='train')
        val_dst = cityscapesLoader(local_path, is_transform=True, split='val')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)
    if args.cuda:
        if class_weight is not None:
            class_weight = class_weight.cuda()
    print('class_weight:', class_weight)
    # Segmentation loaders are built but not consumed in this function.
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=1, shuffle=True)

    # YOLO head layout: B boxes per cell, C classes, SxS grid.
    yolo_B = 2
    yolo_C = 2
    yolo_S = 7
    yolo_out_tensor_shape = yolo_B * 5 + yolo_C
    print('yolo_out_tensor_shape:', yolo_out_tensor_shape)
    det_file_root = os.path.expanduser('~/Data/CamVid/train/')
    # train=False disables augmentation; deterministic order for review.
    det_train_dst = yoloDataset(root=det_file_root, list_file=['camvid_det.txt'], train=False, transform=[transforms.ToTensor()], yolo_out_tensor_shape=yolo_out_tensor_shape)
    det_train_loader = torch.utils.data.DataLoader(det_train_dst, batch_size=1, shuffle=False)

    model = drnsegmt_a_18(pretrained=args.init_vgg16, n_classes=args.n_classes, det_tensor_num=yolo_out_tensor_shape)
    if args.resume_model_state_dict != '':
        pretrained_dict = torch.load(args.resume_model_state_dict, map_location='cpu')
        model.load_state_dict(pretrained_dict)
    else:
        # A trained state dict is mandatory for this visual check.
        print('missing resume_model_state_dict')
        exit()
    if args.cuda:
        model.cuda()
    model.eval()

    for epoch in range(0, 1, 1):  # single pass
        # ----for object detection----
        for det_i, (det_imgs, det_labels, det_imgs_ori) in enumerate(det_train_loader):
            print('det_imgs.shape:', det_imgs.shape)
            print('det_labels.shape:', det_labels.shape)
            det_imgs = Variable(det_imgs)
            det_labels = Variable(det_labels)
            if args.cuda:
                det_imgs = det_imgs.cuda()
                det_labels = det_labels.cuda()
            # Model returns (segmentation logits, detection tensor).
            _, outputs_det = model(det_imgs)
            outputs_det = outputs_det.cpu()
            # decoder() converts the raw YOLO tensor into normalized boxes,
            # class indices and confidences.
            det_boxes, det_cls_indexs, det_probs = decoder(outputs_det)
            # det_imgs_ori is the unnormalized image; assumes HWC layout
            # (height = shape[0], width = shape[1]) — TODO confirm in loader.
            image_ori = det_imgs_ori[0, ...].cpu().data.numpy()
            det_imgs_ori_height = image_ori.shape[0]
            det_imgs_ori_width = image_ori.shape[1]
            for i, det_box in enumerate(det_boxes):
                # Boxes are normalized [0,1]; scale to original pixels.
                x1 = int(det_box[0] * det_imgs_ori_width)
                x2 = int(det_box[2] * det_imgs_ori_width)
                y1 = int(det_box[1] * det_imgs_ori_height)
                y2 = int(det_box[3] * det_imgs_ori_height)
                det_cls_index = det_cls_indexs[i]
                det_cls_index = int(det_cls_index)  # convert LongTensor to int
                det_prob = det_probs[i]
                det_prob = float(det_prob)
                # Discard (rather than clip) boxes falling outside the image.
                if x1 < 0 or x1 > det_imgs_ori_width - 1:
                    continue
                if x2 < 0 or x2 > det_imgs_ori_width - 1:
                    continue
                if y1 < 0 or y1 > det_imgs_ori_height - 1:
                    continue
                if y2 < 0 or y2 > det_imgs_ori_height - 1:
                    continue
                if det_prob > 0:
                    print('(x1,y1)->(x2,y2):({},{})->({},{})'.format(x1, y1, x2, y2))
                    cv2.rectangle(image_ori, (x1, y1), (x2, y2), (0, 0, 255))
            # Blocks until a key is pressed before showing the next image.
            cv2.imshow('image_ori', image_ori)
            cv2.waitKey()
def validate(args):
    """Validate a segmentation model on the CamVid 'val' split.

    The model is either loaded whole from ``args.validate_model`` or built
    from ``args.structure`` (optionally restoring weights from
    ``args.validate_model_state_dict``). Prints overall scores and per-class
    IoU at the end; optionally saves a blended prediction image.
    """
    if args.vis:
        vis = visdom.Visdom()
    if args.dataset_path == '':
        HOME_PATH = os.path.expanduser('~')
        local_path = os.path.join(HOME_PATH, 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, split='val')
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)
    # if os.path.isfile(args.validate_model):
    if args.validate_model != '':
        model = torch.load(args.validate_model)
    else:
        # Dispatch table instead of a long if/elif chain. FIX: an unknown
        # structure used to leave `model` unbound and raise NameError later;
        # now it fails fast with a clear message.
        model_builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=dst.n_classes),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=dst.n_classes),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=dst.n_classes),
            'ResNetDUC': lambda: ResNetDUC(n_classes=dst.n_classes),
            'segnet': lambda: segnet(n_classes=dst.n_classes),
            'ENet': lambda: ENet(n_classes=dst.n_classes),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22', n_classes=dst.n_classes),
            'pspnet': lambda: pspnet(n_classes=dst.n_classes, use_aux=False),
            'erfnet': lambda: erfnet(n_classes=dst.n_classes),
        }
        if args.structure not in model_builders:
            print('missing structure or not support')
            exit(0)
        model = model_builders[args.structure]()
    if args.validate_model_state_dict != '':
        try:
            model.load_state_dict(torch.load(args.validate_model_state_dict))
        except KeyError:
            print('missing key')
    model.eval()
    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        imgs = Variable(imgs)
        labels = Variable(labels)
        outputs = model(imgs)
        # outputs is batch_size*n_classes*height*width; max over the class
        # axis returns (values, indices) -- the indices are the label map
        pred = outputs.data.max(1)[1].numpy()
        gt = labels.data.numpy()
        if args.vis and i % 50 == 0:
            img = imgs.data.numpy()[0]
            label_color = dst.decode_segmap(gt[0]).transpose(2, 0, 1)
            pred_label_color = dst.decode_segmap(pred[0]).transpose(2, 0, 1)
            # try:
            #     win = 'label_color'
            #     vis.image(label_color, win=win)
            #     win = 'pred_label_color'
            #     vis.image(pred_label_color, win=win)
            # except ConnectionError:
            #     print('ConnectionError')
            if args.blend:
                # undo the normalization and re-add the channel means before
                # blending the prediction colors over the input image
                img_hwc = img.transpose(1, 2, 0)
                img_hwc = img_hwc * 255.0
                img_hwc += np.array([104.00699, 116.66877, 122.67892])
                img_hwc = np.array(img_hwc, dtype=np.uint8)
                # label_color_hwc = label_color.transpose(1, 2, 0)
                pred_label_color_hwc = pred_label_color.transpose(1, 2, 0)
                pred_label_color_hwc = np.array(pred_label_color_hwc, dtype=np.uint8)
                label_blend = img_hwc * 0.5 + pred_label_color_hwc * 0.5
                label_blend = np.array(label_blend, dtype=np.uint8)
                misc.imsave('/tmp/label_blend.png', label_blend)
        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)
    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)
    for i in range(dst.n_classes):
        print(i, class_iou[i])
def train(args):
    """Train a CamVid segmentation model with SGD, optional visdom plots and
    periodic state-dict checkpoints.

    Resumes either a whole pickled model (``args.resume_model``) or a state
    dict (``args.resume_model_state_dict``); the start epoch is parsed from
    the checkpoint filename (``..._<epoch>.pt``).
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()
    if args.dataset_path == '':
        HOME_PATH = os.path.expanduser('~')
        local_path = os.path.join(HOME_PATH, 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment)
    dst.n_classes = args.n_classes  # force the configured class count
    trainloader = torch.utils.data.DataLoader(dst, batch_size=args.batch_size, shuffle=True)
    start_epoch = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # epoch number is encoded in the filename: ..._<epoch>.pt
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1+1:start_epoch_id2])
    else:
        if args.structure == 'fcn32s':
            model = fcn(module_type='32s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn16s':
            model = fcn(module_type='16s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'fcn8s':
            model = fcn(module_type='8s', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'ResNetDUC':
            model = ResNetDUC(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'segnet':
            model = segnet(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'ENet':
            model = ENet(n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'drn_d_22':
            model = DRNSeg(model_name='drn_d_22', n_classes=dst.n_classes, pretrained=args.init_vgg16)
        elif args.structure == 'pspnet':
            model = pspnet(n_classes=dst.n_classes, pretrained=args.init_vgg16, use_aux=False)
        elif args.structure == 'erfnet':
            model = erfnet(n_classes=dst.n_classes)
    if args.resume_model_state_dict != '':
        try:
            # fcn32s/fcn16s/fcn8s differ slightly in parameters; loading one
            # into another raises KeyError, caught here for now
            start_epoch_id1 = args.resume_model_state_dict.rfind('_')
            start_epoch_id2 = args.resume_model_state_dict.rfind('.')
            start_epoch = int(args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
            pretrained_dict = torch.load(args.resume_model_state_dict)
            # model_dict = model.state_dict()
            # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # model_dict.update(pretrained_dict)
            model.load_state_dict(pretrained_dict)
        except KeyError:
            print('missing key')
    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, momentum=0.99, weight_decay=5e-4)
    for epoch in range(start_epoch+1, 20000, 1):
        loss_epoch = 0
        loss_avg_epoch = 0
        data_count = 0
        # if args.vis:
        #     vis.text('epoch:{}'.format(epoch), win='epoch')
        for i, (imgs, labels) in enumerate(trainloader):
            print(i)
            # FIX: was `data_count = i`, i.e. batches-1; that skewed the epoch
            # average and divided by zero for single-batch epochs
            data_count = i + 1
            imgs = Variable(imgs)
            labels = Variable(labels)
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)
            if args.vis and i % 50 == 0:
                # FIX: move tensors to CPU before .numpy(); the old direct
                # .data.numpy() calls crash when args.cuda is set
                pred_labels = outputs.cpu().data.max(1)[1].numpy()
                label_color = dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win)
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win)
                if epoch < 100:
                    if not os.path.exists('/tmp/'+init_time):
                        os.mkdir('/tmp/'+init_time)
                    time_str = str(int(time.time()))
                    # FIX: messages used to say transpose(2, 0, 1) while the
                    # code computes transpose(1, 2, 0)
                    print('label_color.transpose(1, 2, 0).shape:', label_color.transpose(1, 2, 0).shape)
                    print('pred_label_color.transpose(1, 2, 0).shape:', pred_label_color.transpose(1, 2, 0).shape)
                    cv2.imwrite('/tmp/'+init_time+'/'+time_str+'_label.png', label_color.transpose(1, 2, 0))
                    cv2.imwrite('/tmp/'+init_time+'/'+time_str+'_pred_label.png', pred_label_color.transpose(1, 2, 0))
            # gradients accumulate unless cleared before each backward
            optimizer.zero_grad()
            loss = cross_entropy2d(outputs, labels)
            loss_numpy = loss.cpu().data.numpy()
            loss_epoch += loss_numpy
            print('loss:', loss_numpy)
            loss.backward()
            optimizer.step()
            # per-iteration loss curve for the current epoch
            if args.vis:
                win = 'loss'
                # FIX: use the already-CPU loss_numpy; loss.data.numpy()
                # crashes on CUDA tensors
                win_res = vis.line(X=np.ones(1)*i, Y=loss_numpy, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1)*i, Y=loss_numpy, win=win)
        # close/clear this epoch's per-iteration loss window
        if args.vis:
            win = 'loss'
            vis.close(win)
        # cross-epoch average-loss curve
        if data_count > 0:
            loss_avg_epoch = loss_epoch / (data_count * 1.0)
        # print(loss_avg_epoch)
        if args.vis:
            win = 'loss_epoch'
            win_res = vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1)*epoch, Y=loss_avg_epoch, win=win)
        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(model.state_dict(), '{}_camvid_class_{}_{}.pt'.format(args.structure, dst.n_classes, epoch))
def performance_table(args):
    """Benchmark a model: average forward/backward time per batch and FLOPs.

    Builds the model named by ``args.structure`` (or loads
    ``args.resume_model`` / ``args.resume_model_state_dict``), runs one
    warm-up forward pass, then times ``args.iterations`` training steps.

    Returns:
        (avg_forward_time, avg_backward_time, model_flops): seconds per batch
        for forward and backward, and the flops-counter result scaled by
        1e9 and 2 -- presumably GMACs; confirm against add_flops_counting_methods.
    """
    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        # NOTE(review): an unknown dataset leaves `dst` unbound and the
        # DataLoader call below raises NameError
        pass
    # dst.n_classes = args.n_classes  # force the configured class count
    trainloader = torch.utils.data.DataLoader(dst, batch_size=args.batch_size, shuffle=True)
    start_epoch = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # epoch number is encoded in the filename: ..._<epoch>.pt
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1 + 1:start_epoch_id2])
    else:
        # NOTE: eval() resolves the structure name to a constructor in scope;
        # only safe with trusted command-line input
        if args.structure == 'drnseg_a_asymmetric_n':
            model = eval(args.structure)(n_classes=dst.n_classes, pretrained=args.init_vgg16, depth_n=args.depth_n)
        elif args.structure == 'drnseg_a_n':
            model = eval(args.structure)(n_classes=dst.n_classes, pretrained=args.init_vgg16, depth_n=args.depth_n)
        else:
            model = eval(args.structure)(n_classes=dst.n_classes, pretrained=args.init_vgg16)
    if args.resume_model_state_dict != '':
        try:
            # fcn32s/fcn16s/fcn8s differ slightly in parameters; loading one
            # into another raises KeyError, caught here for now
            start_epoch_id1 = args.resume_model_state_dict.rfind('_')
            start_epoch_id2 = args.resume_model_state_dict.rfind('.')
            start_epoch = int(
                args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
            pretrained_dict = torch.load(args.resume_model_state_dict)
            # model_dict = model.state_dict()
            # for k, v in pretrained_dict.items():
            #     print(k)
            # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # model_dict.update(pretrained_dict)
            model.load_state_dict(pretrained_dict)
        except KeyError:
            print('missing key')
    model = add_flops_counting_methods(model)
    if args.cuda:
        model.cuda()
    model.train()
    model.start_flops_count()
    # print('start_epoch:', start_epoch)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)
    forward_time = 0
    backward_time = 0
    # first warm-up pass so one-time GPU initialization is not timed
    for i, (imgs, labels) in enumerate(trainloader):
        imgs_batch = imgs.shape[0]
        if imgs_batch != args.batch_size:
            break
        if args.cuda:
            imgs = imgs.cuda()
        model(imgs)
        break
    for epoch in range(0, 1, 1):
        for i, (imgs, labels) in enumerate(trainloader):
            # the last few images may be fewer than batch_size
            # (e.g. batch_size=4 but only 3 images remain); skip them
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break
            # print(i)
            # data_count = i
            # print(labels.shape)
            # print(imgs.shape)
            imgs = Variable(imgs)
            labels = Variable(labels)
            # imgs = Variable(torch.randn(1, 3, 360, 640))
            # labels = Variable(torch.LongTensor(np.ones((1, 360, 640), dtype=np.int)))
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            # synchronize so time.time() brackets the actual GPU work
            if args.cuda:
                torch.cuda.synchronize()
            start = time.time()
            outputs = model(imgs)
            if args.cuda:
                torch.cuda.synchronize()
            end = time.time()
            forward_time += (end - start)
            # print('forward time:', end - start)
            if args.cuda:
                torch.cuda.synchronize()
            start = time.time()
            # gradients accumulate unless cleared before each backward
            optimizer.zero_grad()
            loss = cross_entropy2d(outputs, labels)
            loss.backward()
            optimizer.step()
            if args.cuda:
                torch.cuda.synchronize()
            end = time.time()
            backward_time += (end - start)
            # print('backward time:', end - start)
            if i == args.iterations - 1:
                break
    avg_forward_time = forward_time * 1.0 / args.iterations
    avg_backward_time = backward_time * 1.0 / args.iterations
    print('average forward time:', forward_time * 1.0 / args.iterations)
    print('average backward time:', backward_time * 1.0 / args.iterations)
    model_flops = model.compute_average_flops_cost() / 1e9 / 2
    print('model_flops:', model_flops)
    if args.save_model:
        torch.save(
            model.state_dict(), 'performance_{}_{}_class_{}.pt'.format(args.structure, args.dataset, args.n_classes))
    return avg_forward_time, avg_backward_time, model_flops
def train(args):
    """Train a segmentation model on the selected dataset.

    Supports class weighting (MFB / ENET) on CamVid, gradient accumulation
    (``args.grad_acc_steps``), several solvers and LR policies, periodic
    validation with best-mIoU checkpointing, and visdom visualization of
    losses and metrics.
    """
    now = datetime.datetime.now()
    now_str = '{}-{}-{} {}:{}:{}'.format(now.year, now.month, now.day, now.hour, now.minute, now.second)
    # print('now:', now)
    # print('now_str:', now_str)
    if args.vis:
        # start visdom and close all windows
        vis = visdom.Visdom(env=now_str)
        vis.close()
    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    train_dst = None
    val_dst = None
    if args.dataset == 'CamVid':
        train_dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment, split='train')
        val_dst = camvidLoader(local_path, is_transform=True, is_augment=False, split='val')
        trainannot_image_dir = os.path.expanduser(os.path.join(local_path, "trainannot"))
        trainannot_image_files = [
            os.path.join(trainannot_image_dir, file)
            for file in os.listdir(trainannot_image_dir)
            if file.endswith('.png')
        ]
        if args.class_weighting == 'MFB':
            class_weight = median_frequency_balancing(trainannot_image_files, num_classes=12)
            class_weight = torch.tensor(class_weight)
        elif args.class_weighting == 'ENET':
            class_weight = ENet_weighing(trainannot_image_files, num_classes=12)
            class_weight = torch.tensor(class_weight)
    elif args.dataset == 'CityScapes':
        train_dst = cityscapesLoader(local_path, is_transform=True, split='train')
        val_dst = cityscapesLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'SegmPred':
        # no separate val split here; validation runs on 'train'
        train_dst = segmpredLoader(local_path, is_transform=True, split='train')
        val_dst = segmpredLoader(local_path, is_transform=True, split='train')
    elif args.dataset == 'MovingMNIST':
        # class_weight = [0.1, 0.5]
        # class_weight = torch.tensor(class_weight)
        train_dst = movingmnistLoader(local_path, is_transform=True, split='train')
        val_dst = movingmnistLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'FreeSpace':
        train_dst = freespaceLoader(local_path, is_transform=True, split='train')
        val_dst = freespaceLoader(local_path, is_transform=True, split='val')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)
    if args.cuda:
        if class_weight is not None:
            class_weight = class_weight.cuda()
    print('class_weight:', class_weight)
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=1, shuffle=True)
    start_epoch = 0
    best_mIoU = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # epoch number is encoded in the filename: ..._<epoch>.pt
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1 + 1:start_epoch_id2])
    else:
        # NOTE: eval() resolves the structure name to a constructor in scope;
        # only safe with trusted command-line input
        try:
            model = eval(args.structure)(n_classes=args.n_classes, pretrained=args.init_vgg16)
        except Exception:  # narrowed from bare except (kept SystemExit alive)
            print('missing structure or not support')
            exit(0)
        # ---------------for testing SegmPred---------------
        # FIX: `input_channel` used to be unbound for other datasets, raising
        # NameError whenever args.structure == 'drnseg_a_18'
        input_channel = None
        if args.dataset == 'MovingMNIST':
            input_channel = 1 * 9
        elif args.dataset == 'SegmPred':
            input_channel = 19 * 4
        if args.structure == 'drnseg_a_18' and input_channel is not None:
            model = drnseg_a_18(n_classes=args.n_classes, pretrained=args.init_vgg16, input_channel=input_channel)
        # ---------------for testing SegmPred---------------
    if args.resume_model_state_dict != '':
        try:
            # checkpoint names look like
            # '<structure>_<ds>_miou_<m>_class_<c>_<epoch>.pt';
            # recover the best mIoU and start epoch from the filename
            miou_model_name_str = '_miou_'
            class_model_name_str = '_class_'
            miou_id1 = args.resume_model_state_dict.find(miou_model_name_str) + len(miou_model_name_str)
            miou_id2 = args.resume_model_state_dict.find(class_model_name_str)
            best_mIoU = float(args.resume_model_state_dict[miou_id1:miou_id2])
            start_epoch_id1 = args.resume_model_state_dict.rfind('_')
            start_epoch_id2 = args.resume_model_state_dict.rfind('.')
            start_epoch = int(args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
            pretrained_dict = torch.load(args.resume_model_state_dict, map_location='cpu')
            model.load_state_dict(pretrained_dict)
        except KeyError:
            print('missing resume_model_state_dict or wrong type')
    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    print('best_mIoU:', best_mIoU)
    if args.solver == 'SGD':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'RMSprop':
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'Adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=5e-4)
    else:
        print('missing solver or not support')
        exit(0)
    # when the observed metric stops decreasing the scheduler lowers the LR
    # scheduler = ReduceLROnPlateau(optimizer, 'min', patience=100, min_lr=1e-10, verbose=True)
    if args.lr_policy == 'Constant':
        scheduler = ConstantLR(optimizer)
    elif args.lr_policy == 'Polynomial':
        scheduler = PolynomialLR(optimizer, max_iter=args.training_epoch, power=0.9)  # base lr=0.01 power=0.9 like PSPNet
    elif args.lr_policy == 'MultiStep':
        scheduler = MultiStepLR(optimizer, milestones=[10, 50, 90], gamma=0.1)
        # scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
    else:
        # FIX: an unknown lr_policy used to leave `scheduler` unbound and
        # crash with NameError at scheduler.step()
        print('missing lr_policy or not support')
        exit(0)
    data_count = int(train_dst.__len__() * 1.0 / args.batch_size)
    print('data_count:', data_count)
    # iteration_step = 0
    train_gts, train_preds = [], []
    for epoch in range(start_epoch + 1, args.training_epoch, 1):
        loss_epoch = 0
        scheduler.step()
        # zero once per epoch; gradients then accumulate across
        # grad_acc_steps iterations before each optimizer.step()
        optimizer.zero_grad()
        for i, (imgs, labels) in enumerate(train_loader):
            # if i==1:
            #     break
            model.train()
            # the last batch may be smaller than batch_size; skip it
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break
            # iteration_step += 1
            imgs = Variable(imgs)
            labels = Variable(labels)
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)
            # average the loss over the accumulation steps before backward
            loss = cross_entropy2d(outputs, labels, weight=class_weight)
            loss_grad_acc_avg = loss * 1.0 / args.grad_acc_steps
            loss_grad_acc_avg.backward()
            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np
            if (i + 1) % args.grad_acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            # ------------------train metrics-------------------------------
            train_pred = outputs.cpu().data.max(1)[1].numpy()
            train_gt = labels.cpu().data.numpy()
            for train_gt_, train_pred_ in zip(train_gt, train_pred):
                train_gts.append(train_gt_)
                train_preds.append(train_pred_)
            # ------------------train metrics-------------------------------
            if args.vis and i % 50 == 0:
                pred_labels = outputs.cpu().data.max(1)[1].numpy()
                label_color = train_dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = train_dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win, opts=dict(title='Gt', caption='Ground Truth'))
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win, opts=dict(title='Pred', caption='Prediction'))
            # per-iteration loss curve
            if args.vis:
                win = 'loss_iteration'
                loss_np_expand = np.expand_dims(loss_np, axis=0)
                win_res = vis.line(X=np.ones(1) * (i + data_count * (epoch - 1) + 1), Y=loss_np_expand, win=win, update='append')
                if win_res != win:
                    vis.line(X=np.ones(1) * (i + data_count * (epoch - 1) + 1), Y=loss_np_expand, win=win,
                             opts=dict(title=win, xlabel='iteration', ylabel='loss'))
        # validate on the val split and checkpoint when mIoU improves
        if args.val_interval > 0 and epoch % args.val_interval == 0:
            print('----starting val----')
            model.eval()
            val_gts, val_preds = [], []
            for val_i, (val_imgs, val_labels) in enumerate(val_loader):
                # print(val_i)
                val_imgs = Variable(val_imgs, volatile=True)
                val_labels = Variable(val_labels, volatile=True)
                if args.cuda:
                    val_imgs = val_imgs.cuda()
                    val_labels = val_labels.cuda()
                val_outputs = model(val_imgs)
                val_pred = val_outputs.cpu().data.max(1)[1].numpy()
                val_gt = val_labels.cpu().data.numpy()
                for val_gt_, val_pred_ in zip(val_gt, val_pred):
                    val_gts.append(val_gt_)
                    val_preds.append(val_pred_)
            score, class_iou = scores(val_gts, val_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Mean IoU : \t':
                    v_iou = v
                    if v > best_mIoU:
                        best_mIoU = v_iou
                        torch.save(
                            model.state_dict(),
                            '{}_{}_miou_{}_class_{}_{}.pt'.format(args.structure, args.dataset, best_mIoU, args.n_classes, epoch))
                    # validation-interval mIoU curve
                    if args.vis:
                        win = 'mIoU_epoch'
                        v_iou_expand = np.expand_dims(v_iou, axis=0)
                        win_res = vis.line(X=np.ones(1) * epoch * args.val_interval, Y=v_iou_expand, win=win, update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1) * epoch * args.val_interval, Y=v_iou_expand, win=win,
                                     opts=dict(title=win, xlabel='epoch', ylabel='mIoU'))
            for class_i in range(args.n_classes):
                print(class_i, class_iou[class_i])
            print('----ending val----')
        # cross-epoch average-loss curve
        loss_avg_epoch = loss_epoch / (data_count * 1.0)
        # print(loss_avg_epoch)
        if args.vis:
            win = 'loss_epoch'
            loss_avg_epoch_expand = np.expand_dims(loss_avg_epoch, axis=0)
            win_res = vis.line(X=np.ones(1) * epoch, Y=loss_avg_epoch_expand, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1) * epoch, Y=loss_avg_epoch_expand, win=win,
                         opts=dict(title=win, xlabel='epoch', ylabel='loss'))
        if args.vis:
            win = 'lr_epoch'
            lr_epoch = np.array(scheduler.get_lr())
            lr_epoch_expand = np.expand_dims(lr_epoch, axis=0)
            win_res = vis.line(X=np.ones(1) * epoch, Y=lr_epoch_expand, win=win, update='append')
            if win_res != win:
                vis.line(X=np.ones(1) * epoch, Y=lr_epoch_expand, win=win,
                         opts=dict(title=win, xlabel='epoch', ylabel='lr'))
        # ------------------train metrics-------------------------------
        if args.vis:
            score, class_iou = scores(train_gts, train_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Overall Acc : \t':
                    overall_acc = v
                    if args.vis:
                        win = 'acc_epoch'
                        overall_acc_expand = np.expand_dims(overall_acc, axis=0)
                        win_res = vis.line(X=np.ones(1) * epoch, Y=overall_acc_expand, win=win, update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1) * epoch, Y=overall_acc_expand, win=win,
                                     opts=dict(title=win, xlabel='epoch', ylabel='accuracy'))
            # clear for the next epoch's training metrics
            train_gts, train_preds = [], []
        # ------------------train metrics-------------------------------
        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(
                model.state_dict(),
                '{}_{}_class_{}_{}.pt'.format(args.structure, args.dataset, args.n_classes, epoch))
def _benchmark_dataset(dst, batch_size, drop_last=False):
    """Iterate `dst` once through a DataLoader and print total time and fps.

    Extracted helper: the original script duplicated this timing/printing
    code verbatim for both datasets.
    """
    loader = DataLoader(dst, batch_size=batch_size, drop_last=drop_last)
    time_start = time.time()
    for idx, data in enumerate(loader):
        pass
        # print("idx:", idx)
    time_end = time.time()
    elapsed = time_end - time_start
    print('load {} images cost time: {} sec'.format(len(dst), elapsed))
    print('load {} images {} fps'.format(len(dst), len(dst) * 1.0 / elapsed))


batch_size = 1
# LMDB-backed dataset benchmark (drop_last matches the original call)
_benchmark_dataset(ImageFolderLMDB(lmdb_path, None), batch_size, drop_last=True)
# raw-file CamVid dataset benchmark for comparison
local_path = os.path.join(os.path.expanduser('~/Data/CamVid'))
_benchmark_dataset(camvidLoader(local_path, is_transform=False, is_augment=False), batch_size)
def validate(args):
    """Validate a prediction model (drnsegpred) on ``args.dataset_type``.

    Optionally dumps colorized GT/prediction images to /tmp/<timestamp>/,
    prints the images with the lowest per-image loss, and reports overall
    scores plus per-class IoU.
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()
    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'SegmPred':
        dst = segmpredLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'MovingMNIST':
        dst = movingmnistLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'FreeSpace':
        dst = freespaceLoader(local_path, is_transform=True, split=args.dataset_type)
    else:
        pass
    val_loader = torch.utils.data.DataLoader(dst, batch_size=1, shuffle=False)
    # if os.path.isfile(args.validate_model):
    if args.validate_model != '':
        model = torch.load(args.validate_model)
    else:
        # ---------------for testing SegmPred---------------
        try:
            model = drnsegpred_a_18(n_classes=args.n_classes, pretrained=args.init_vgg16, input_shape=dst.input_shape)
        except Exception:  # narrowed from bare except
            print('missing structure or not support')
            exit(0)
    if args.validate_model_state_dict != '':
        try:
            model.load_state_dict(
                torch.load(args.validate_model_state_dict, map_location='cpu'))
        except KeyError:
            print('missing key')
    # ---------------for testing SegmPred---------------
    if args.cuda:
        model.cuda()
    # some models perform differently depending on the mode they are run in
    model.eval()
    # model.train()
    gts, preds, errors, imgs_name = [], [], [], []
    for i, (imgs, labels) in enumerate(val_loader):
        print(i)
        # if i==1:
        #     break
        img_path = dst.files[args.dataset_type][i]
        img_name = img_path[img_path.rfind('/') + 1:]
        imgs_name.append(img_name)
        # wrap numpy data as (volatile) autograd Variables -- legacy PyTorch API
        imgs = Variable(imgs, volatile=True)
        labels = Variable(labels, volatile=True)
        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()
        outputs = model(imgs)
        loss = cross_entropy2d(outputs, labels)
        loss_np_float = float(loss.cpu().data.numpy())
        # per-image loss, used below to rank the images
        errors.append(loss_np_float)
        # outputs is batch_size*n_classes*height*width; max over the class
        # axis returns (values, indices) -- the indices are the label map
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()
        if args.save_result:
            if not os.path.exists('/tmp/' + init_time):
                os.mkdir('/tmp/' + init_time)
            pred_labels = outputs.cpu().data.max(1)[1].numpy()
            label_color = dst.decode_segmap(labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
            pred_label_color = dst.decode_segmap(pred_labels[0]).transpose(2, 0, 1)
            # decode_segmap yields RGB; OpenCV writes BGR
            label_color_cv2 = label_color.transpose(1, 2, 0)
            label_color_cv2 = cv2.cvtColor(label_color_cv2, cv2.COLOR_RGB2BGR)
            cv2.imwrite('/tmp/' + init_time + '/gt_{}.png'.format(img_name), label_color_cv2)
            pred_label_color_cv2 = pred_label_color.transpose(1, 2, 0)
            pred_label_color_cv2 = cv2.cvtColor(pred_label_color_cv2, cv2.COLOR_RGB2BGR)
            cv2.imwrite('/tmp/' + init_time + '/pred_{}.png'.format(img_name), pred_label_color_cv2)
        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)
    # indices of images sorted by per-image loss, ascending
    errors_indices = np.argsort(errors).tolist()
    print('errors_indices:', errors_indices)
    # FIX: errors_indices[rank] is the image index of the rank-th smallest
    # loss; the old errors_indices.index(rank) looked up the *rank of image
    # `rank`* instead, printing unrelated names and raising ValueError for
    # datasets with fewer than 10 images.
    for rank in range(min(10, len(errors_indices))):
        top_index = errors_indices[rank]
        img_name_top = imgs_name[top_index]
        print('img_name_top:', img_name_top)
    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)
    class_iou_list = []
    for i in range(dst.n_classes):
        class_iou_list.append(round(class_iou[i], 2))
        # print(i, round(class_iou[i], 2))
    # FIX: materialize the range -- under Python 3 the bare range printed
    # as 'range(0, n)' instead of the class-id list
    print('classes:', list(range(dst.n_classes)))
    print('class_iou_list:', class_iou_list)