def validate(val_loader, model, criterion, eval_score=None, print_freq=10,
             confusion_matrix=None):
    """Run one full evaluation pass over ``val_loader``.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss module. For L1/MSE losses the target is cast to
            float first (those losses require float targets).
        eval_score: optional callable ``(output, target) -> scalar`` used
            for the running "Score" meter.
        print_freq: log progress every ``print_freq`` batches.
        confusion_matrix: optional ``RunningConfusionMatrix``. When None,
            a 19-class matrix (Cityscapes label set) is created, which
            preserves the previous hard-coded behavior.

    Returns:
        ``(miou, avg_loss)`` — mean IoU over the pass and the average
        validation loss, matching the ``prec1, loss_val = validate(...)``
        unpacking done by ``train_seg``.
    """
    if confusion_matrix is None:
        # Previous behavior: fixed 19-class (Cityscapes) evaluation.
        confusion_matrix = RunningConfusionMatrix(np.arange(0, 19))

    batch_time = AverageMeter()
    losses = AverageMeter()
    score = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # torch.no_grad() replaces the removed Variable(volatile=True) API.
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            if type(criterion) in [
                    torch.nn.modules.loss.L1Loss,
                    torch.nn.modules.loss.MSELoss
            ]:
                target = target.float()

            input = input.cuda()
            # `async` became a reserved keyword in Python 3.7; the
            # replacement kwarg is `non_blocking`.
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)
            confusion_matrix.update_matrix(target, output)

            # measure accuracy and record loss
            # .item() replaces the removed 0-dim indexing `loss.data[0]`.
            losses.update(loss.item(), input.size(0))
            if eval_score is not None:
                score.update(eval_score(output, target), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Score {score.val:.3f} ({score.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, score=score), flush=True)

    miou, top_1, top_5 = \
        confusion_matrix.compute_current_mean_intersection_over_union()
    print(' * Score {top1.avg:.3f}'.format(top1=score))
    print(' * mIoU {top1:.3f}'.format(top1=miou))
    confusion_matrix.show_classes()
    # Return the average loss too, so train_seg can log 'loss/epoch'.
    return miou, losses.avg
def test(eval_data_loader, model, num_classes, output_dir='pred',
         has_gt=True, save_vis=False):
    """Run inference over ``eval_data_loader`` and optionally score it.

    Args:
        eval_data_loader: iterable of ``(image, label, name, size)``
            batches.
        model: network to evaluate; switched to eval mode here. Its
            output is indexed ``[0]`` and ``torch.exp`` is applied to
            obtain probabilities, so the model presumably emits
            log-probabilities — TODO confirm against the model head.
        num_classes: number of segmentation classes for the histogram.
        output_dir: prefix for prediction / probability / color images.
        has_gt: when True, accumulate a confusion histogram against the
            labels and return mean IoU.
        save_vis: when True, write prediction images to disk.

    Returns:
        Mean IoU in percent, rounded to 2 decimals, when ``has_gt`` is
        True; otherwise ``None``.
    """
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    hist = np.zeros((num_classes, num_classes))
    # torch.no_grad() replaces the removed Variable(volatile=True) API.
    with torch.no_grad():
        # `batch_idx` instead of `iter`, which shadowed the builtin.
        for batch_idx, (image, label, name, size) in \
                enumerate(eval_data_loader):
            data_time.update(time.time() - end)
            final = model(image)[0]
            _, pred = torch.max(final, 1)
            pred = pred.cpu().data.numpy()
            batch_time.update(time.time() - end)
            prob = torch.exp(final)
            if save_vis:
                save_output_images(pred, name, output_dir, size)
                if prob.size(1) == 2:
                    save_prob_images(prob, name, output_dir + '_prob',
                                     size)
                else:
                    save_colorful_images(pred, name,
                                         output_dir + '_color',
                                         CITYSCAPE_PALLETE, size)
            if has_gt:
                label = label.numpy()
                hist += fast_hist(pred.flatten(), label.flatten(),
                                  num_classes)
                # print('===> mAP {mAP:.3f}'.format(
                #     mAP=round(np.nanmean(per_class_iu(hist)) * 100, 2)))
            end = time.time()
            print('Eval: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  .format(batch_idx, len(eval_data_loader),
                          batch_time=batch_time, data_time=data_time))
    ious = per_class_iu(hist) * 100
    if has_gt:  # val
        return round(np.nanmean(ious), 2)
def train_seg(args):
    """Full segmentation training loop driven by CLI ``args``.

    Builds a DLA-up model, train/val dataloaders over hard-coded
    Argoverse paths, optionally resumes from a checkpoint, then runs
    train + validate per epoch, checkpointing and logging to
    TensorBoard.  With ``args.evaluate`` set, runs a single validation
    pass and returns.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    # Echo the exact command line and every parsed argument for the log.
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base  # NOTE(review): read but never used below
    # print(dla_up.__dict__.get(args.arch))
    # Look up the architecture constructor by name from the dla_up module.
    single_model = dla_up.__dict__.get(args.arch)(
        classes=args.classes, down_ratio=args.down)
    single_model = convert_model(single_model)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')

    # NLL loss over log-probabilities; ignore_index=-1 skips unlabeled
    # pixels.  edge_weight > 0 up-weights the second (edge) class.
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        # criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
        criterion = nn.NLLLoss2d(ignore_index=-1, weight=weight)
    else:
        # criterion = nn.NLLLoss2d(ignore_index=255)
        criterion = nn.NLLLoss2d(ignore_index=-1)
    criterion.cuda()

    # Training-time augmentation pipeline (project `transforms` module).
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  #TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip()])  #TODO
    # Validation uses only a crop (no augmentation).
    t_val = []
    t_val.append(transforms.RandomCrop(crop_size))

    # NOTE(review): dataset locations are hard-coded to one machine;
    # consider promoting these to CLI arguments.
    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/' + args.target + '/'
    my_train = BasicDataset(dir_img, dir_mask, transforms.Compose(t),
                            is_train=True)
    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/' + args.target + '/'
    my_val = BasicDataset(val_dir_img, val_dir_mask,
                          transforms.Compose(t_val), is_train=True)

    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        my_val, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)  #TODO batch_size
    print("loader created")

    # Optimizer over the model's own parameter groups; no LR scheduler
    # object — adjust_learning_rate() is called manually each epoch.
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = None  #TODO
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # 5-class confusion matrix by default; evaluate mode switches to a
    # reduced 2-class matrix.
    confusion_labels = np.arange(0, 5)
    val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                 ignore_label=-1)
    if args.evaluate:
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader, model, criterion,
                 confusion_matrix=val_confusion_matrix)
        return

    writer = SummaryWriter(comment=args.log)
    # TODO test val
    # print("test val")
    # prec1 = validate(val_loader, model, criterion, confusion_matrix=val_confusion_matrix)
    for epoch in range(start_epoch, args.epochs):
        # Fresh per-epoch confusion matrices so stats don't accumulate
        # across epochs.
        train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                        ignore_label=-1)
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              lr_scheduler, confusion_matrix=train_confusion_matrix,
              writer=writer)

        # NOTE(review): this checkpoint is overwritten by the second
        # save_checkpoint below (same path); it only guards against a
        # crash during validation.
        checkpoint_path = os.path.join(
            checkpoint_dir, 'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)

        # evaluate on validation set
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1)
        # NOTE(review): assumes validate() returns (mIoU, loss) —
        # confirm the validate() signature matches.
        prec1, loss_val = validate(val_loader, model, criterion,
                                   confusion_matrix=val_confusion_matrix)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)

        # Definitive per-epoch checkpoint, now including best_prec1.
        checkpoint_path = os.path.join(
            checkpoint_dir, 'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        # Periodically keep a permanent copy in the working directory.
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
    writer.close()