コード例 #1
0
ファイル: train.py プロジェクト: zacario-li/Fast-SCNN_pytorch
def prepareDataset(rootpath, trainlist, vallist, mean, std,
                   train_batch_size=160, train_workers=32,
                   val_batch_size=4, val_workers=4):
    """Build the training and validation data loaders.

    The crop size is read from the name ``inputHW``, which is defined outside
    this function.

    Args:
        rootpath: Dataset root, forwarded to ``dataset.SemData`` as
            ``data_root``.
        trainlist: Training list, forwarded as ``data_list`` for the
            ``'train'`` split.
        vallist: Validation list, forwarded as ``data_list`` for the
            ``'val'`` split.
        mean: Normalization mean; also used as the padding value when
            cropping.
        std: Normalization standard deviation.
        train_batch_size: Batch size of the training loader.
        train_workers: Number of worker processes of the training loader.
        val_batch_size: Batch size of the validation loader.
        val_workers: Number of worker processes of the validation loader.

    Returns:
        A ``(trainDataLoader, valDataLoader)`` tuple.
    """
    # Training pipeline: random augmentation followed by a random crop.
    train_transform = transform.Compose([
        transform.RandScale([0.5, 2]),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop(inputHW, crop_type='rand', padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    # Validation pipeline: center crop only, no random augmentation.
    val_transform = transform.Compose([
        transform.Crop(inputHW, crop_type='center', padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])

    train_data = dataset.SemData(split='train',
                                 data_root=rootpath,
                                 data_list=trainlist,
                                 transform=train_transform)
    # drop_last=True: the incomplete final training batch is discarded.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=train_workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_data = dataset.SemData(split='val',
                               data_root=rootpath,
                               data_list=vallist,
                               transform=val_transform)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=val_batch_size,
                                             shuffle=False,
                                             num_workers=val_workers,
                                             pin_memory=True)
    return train_loader, val_loader
コード例 #2
0
def train(gpu, ngpus_per_node, argss):
    """Run the iteration-based training loop in one worker process.

    Sets the module-level ``args`` (and, in the main process, ``logger`` and
    ``writer``), builds the model, optimizer, loss and data loaders, then
    trains from ``args.start_iter`` to ``args.max_iter``. At the end of every
    pass over the training loader (and at the last iteration) it logs the
    accumulated training metrics, optionally validates, and saves a
    checkpoint from the main process.

    Args:
        gpu: CUDA device index for this process; used for
            ``torch.cuda.set_device``, ``device_ids`` and ``.cuda(gpu)``.
        ngpus_per_node: GPUs per node; used to derive the global rank and to
            split batch sizes and worker counts in distributed mode.
        argss: Configuration object; stored in the module-level ``args``.

    Raises:
        ValueError: If ``args.arch`` is not ``'dct'`` (the only architecture
            this function constructs).
    """
    global args, logger, writer
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # NOTE(review): the teacher model and kd_criterion are constructed below
    # but the training loop never calls them; only the save path is affected.
    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # exist_ok avoids a check-then-create race between worker processes.
        os.makedirs(args.save_path, exist_ok=True)

    if args.arch != 'dct':
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            "Unsupported arch '{}': train() only builds the 'dct' model".format(
                args.arch))
    model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
    modules_ori = [model.cp, model.sp, model.head]
    modules_new = []
    # Parameter groups at index >= index_split train with 10x learning rate.
    args.index_split = len(modules_ori)

    params_list = [
        dict(params=module.parameters(), lr=args.base_lr)
        for module in modules_ori
    ]
    params_list += [
        dict(params=module.parameters(), lr=args.base_lr * 10)
        for module in modules_new
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                teacher_model)

    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(gpu)
        # The configured sizes are per node; split them across its GPUs.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(
            args.teacher_model_path,
            map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(
            args.teacher_model_path))

    if args.use_ohem:
        criterion = OhemCELoss(thresh=0.7,
                               ignore_index=args.ignore_label).cuda(gpu)
    else:
        criterion = nn.CrossEntropyLoss(
            ignore_index=args.ignore_label).cuda(gpu)

    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(gpu)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        elif main_process():
            logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto the GPU.
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_iter = checkpoint['iteration']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (iteration {})".format(
                    args.resume, checkpoint['iteration']))
        elif main_process():
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # RGB mean & std, scaled to the 0-255 value range.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=rgb_mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=rgb_mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    train_data = dataset.SemData(split='train',
                                 img_type='rgb',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    # shuffle and sampler are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=rgb_mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=rgb_mean, std=rgb_std)
        ])
        val_data = dataset.SemData(split='val',
                                   img_type='rgb',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    # Training loop state.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    # Switch to train mode; the teacher (if any) stays in eval mode.
    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    end = time.time()
    max_iter = args.max_iter

    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            images, target = next(data_iter)
            if target.size(0) != args.batch_size:
                # Treat a short batch like the end of the epoch.
                raise StopIteration
        except StopIteration:
            epoch += 1
            if args.distributed:
                train_sampler.set_epoch(epoch)
                if main_process():
                    logger.info('train_sampler.set_epoch({})'.format(epoch))
            data_iter = iter(train_loader)
            images, target = next(data_iter)
            # Per-epoch meters start over with each new pass over the data.
            main_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # Measure data loading time.
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Forward pass and loss.
        main_out = model(images)
        main_loss = criterion(main_out, target)
        if not args.multiprocessing_distributed:
            main_loss = torch.mean(main_loss)
        loss = main_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = target.size(0)
        if args.multiprocessing_distributed:
            # Sample-weighted average over all processes (ignored pixels are
            # not accounted for). Detach first so the reduction works on
            # plain values rather than on the autograd graph.
            main_loss = main_loss.detach() * n
            loss = loss.detach() * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss)
            dist.all_reduce(loss)
            dist.all_reduce(count)
            n = count.item()
            main_loss, loss = main_loss / n, loss / n

        # Highest-scoring class index along dim 1.
        prediction = main_out.detach().max(1)[1]
        intersection, union, target_count = intersectionAndUnionGPU(
            prediction, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection)
            dist.all_reduce(union)
            dist.all_reduce(target_count)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        target_count = target_count.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target_count)

        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Poly learning-rate schedule, applied per parameter group.
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        # Estimated time remaining, formatted as HH:MM:SS.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        iter_log = current_iter + 1
        if iter_log % args.print_freq == 0 and main_process():
            logger.info('Iter [{}/{}] '
                        'LR: {lr:.3e}, '
                        'ETA: {remain_time}, '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f}), '
                        'Batch: {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                        'MainLoss: {main_loss_meter.val:.4f}, '
                        'Loss: {loss_meter.val:.4f}, '
                        'Accuracy: {accuracy:.4f}.'.format(
                            iter_log,
                            args.max_iter,
                            lr=current_lr,
                            remain_time=remain_time,
                            data_time=data_time,
                            batch_time=batch_time,
                            main_loss_meter=main_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              iter_log)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              iter_log)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (target_count + 1e-10)),
                              iter_log)
            writer.add_scalar('allAcc_train_batch', accuracy, iter_log)

        # End of a pass over the training loader, or the final iteration.
        if iter_log % len(train_loader) == 0 or iter_log == max_iter:
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            accuracy_class = intersection_meter.sum / (target_meter.sum +
                                                       1e-10)
            mIoU_train = np.mean(iou_class)
            mAcc_train = np.mean(accuracy_class)
            allAcc_train = sum(
                intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
            loss_train = main_loss_meter.avg
            if main_process():
                logger.info(
                    'Train result at iteration [{}/{}]: mIoU/mAcc/allAcc '
                    '{:.4f}/{:.4f}/{:.4f}.'.format(iter_log, max_iter,
                                                   mIoU_train, mAcc_train,
                                                   allAcc_train))
                writer.add_scalar('loss_train', loss_train, iter_log)
                writer.add_scalar('mIoU_train', mIoU_train, iter_log)
                writer.add_scalar('mAcc_train', mAcc_train, iter_log)
                writer.add_scalar('allAcc_train', allAcc_train, iter_log)

            is_best = False
            if args.evaluate:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                    val_loader, model, criterion)
                model.train()  # validate() leaves the model in eval mode
                if main_process():
                    writer.add_scalar('loss_val', loss_val, iter_log)
                    writer.add_scalar('mIoU_val', mIoU_val, iter_log)
                    writer.add_scalar('mAcc_val', mAcc_val, iter_log)
                    writer.add_scalar('allAcc_val', allAcc_val, iter_log)

                    if best_mIoU_val < mIoU_val:
                        is_best = True
                        best_mIoU_val = mIoU_val
                        logger.info('==>The best val mIoU: %.3f' %
                                    (best_mIoU_val))

            if main_process():
                save_checkpoint(
                    {
                        'iteration': iter_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_mIoU_val': best_mIoU_val
                    }, is_best, args.save_path)
                if args.evaluate:
                    logger.info(
                        'Saving checkpoint to:{}/iter_{}.pth or last.pth '
                        'with mIoU:{:.3f}'.format(args.save_path, iter_log,
                                                  mIoU_val))
                else:
                    # No validation was run, so there is no mIoU to report.
                    logger.info(
                        'Saving checkpoint to:{}/iter_{}.pth or last.pth'.
                        format(args.save_path, iter_log))
                if is_best:
                    logger.info(
                        'Saving checkpoint to:{}/best.pth with mIoU:{:.3f}'.
                        format(args.save_path, best_mIoU_val))

    if main_process():
        # The writer must be closed, otherwise an EOFError can appear.
        writer.close()
        logger.info(
            '==>Training done! The best val mIoU during training: %.3f' %
            (best_mIoU_val))
コード例 #3
0
def evaluate(respth='./res',
             dspth='/data2/.encoding/data/cityscapes',
             checkpoint=None):
    """Restore a FANet-18 model from ``args.model_path`` and log its mIOU.

    The model path, data root and validation list all come from
    ``get_parser()``. The three parameters are accepted but not read by the
    body.

    Args:
        respth: Unused.
        dspth: Unused.
        checkpoint: Unused.

    Raises:
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()

    # Banner announcing the evaluation run.
    for banner_line in ('\n', '====' * 20, 'evaluating the model ...\n',
                        'setup and restore model'):
        logger.info(banner_line)

    # Model: 19 output classes on an 18-layer backbone.
    num_classes = 19
    net = FANet(layers=18, classes=num_classes)

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    saved_state = torch.load(args.model_path)
    # strict=False: key mismatches between checkpoint and model are tolerated.
    net.load_state_dict(saved_state['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    net.cuda()
    net.eval()

    # Dataset: normalization statistics scaled to the 0-255 value range.
    value_scale = 255
    mean = [channel * value_scale for channel in (0.485, 0.456, 0.406)]
    std = [channel * value_scale for channel in (0.229, 0.224, 0.225)]
    preprocessing = transform.Compose([
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    val_set = dataset.SemData(split='val',
                              data_root=args.data_root,
                              data_list=args.val_list,
                              transform=preprocessing)
    loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True)

    # Evaluator: single scale, no flipping.
    logger.info('compute the mIOU')
    mIOU = MscEval(net, loader, scales=[1.0], flip=False).evaluate()
    logger.info('mIOU is: {:.6f}'.format(mIOU))
コード例 #4
0
def evaluate():
    """Evaluate a checkpoint on the validation split and log the metrics.

    Reads its configuration from ``get_parser()``, builds the model selected
    by ``args.arch``, loads ``args.model_path`` into it, runs it over the
    validation loader with batch size 1, and logs mIoU, mAcc, allAcc and the
    per-class IoU and accuracy.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    if args.arch == 'psp':
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor)
    elif args.arch == 'nonlocal':
        model = Nonlocal(layers=args.layers,
                         classes=args.classes,
                         zoom_factor=args.zoom_factor)
    elif args.arch == 'danet':
        model = DANet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif args.arch == 'sanet':
        model = SANet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif args.arch == 'fanet':
        model = FANet(layers=args.layers, classes=args.classes)
    elif args.arch == 'fftnet':
        model = FFTNet(layers=args.layers, classes=args.classes)
    elif args.arch == 'fftnet_23':
        model = FFTNet23(layers=args.layers, classes=args.classes)
    elif args.arch == 'bise_v1':
        model = BiseNet(layers=args.layers,
                        classes=args.classes,
                        with_sp=args.with_sp)
    elif args.arch == 'dct':
        model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
    elif args.arch == 'triple':
        model = TriSeNet(layers=args.layers, classes=args.classes)
    elif args.arch == 'triple_1':
        model = TriSeNet1(layers=args.layers, classes=args.classes)
    elif args.arch == 'ppm':
        model = PPM_Net(backbone=args.backbone,
                        layers=args.layers,
                        classes=args.classes)
    elif args.arch == 'fc':
        model = FC_Net(backbone=args.backbone,
                       layers=args.layers,
                       classes=args.classes)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError("Unsupported arch '{}'".format(args.arch))

    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # RGB mean & std, scaled to the 0-255 value range.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    val_h = int(args.base_h * args.scale)
    val_w = int(args.base_w * args.scale)

    val_transform = transform.Compose([
        transform.Resize(size=(val_h, val_w)),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    val_data = dataset.SemData(split='val',
                               img_type='rgb',
                               data_root=args.data_root,
                               data_list=args.val_list,
                               transform=val_transform)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    logger.info('>>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    end = time.time()

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            data_time.update(time.time() - end)
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            if args.teacher_model_path is not None and args.arch == 'sanet':
                # In this configuration the model returns three values; only
                # the first is used as the prediction.
                output, _, _ = model(images)
            else:
                output = model(images)
            # Highest-scoring class index along dim 1. Predictions are not
            # stored: only the running counts below are needed.
            output = output.detach().max(1)[1]
            intersection, union, target_count = intersectionAndUnionGPU(
                output, target, args.classes, args.ignore_label)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            target_count = target_count.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target_count)

            accuracy = sum(
                intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            end = time.time()
            if ((i + 1) % 10 == 0) or (i + 1 == len(val_loader)):
                logger.info(
                    'Val: [{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Accuracy {accuracy:.4f}.'.format(i + 1,
                                                      len(val_loader),
                                                      data_time=data_time,
                                                      batch_time=batch_time,
                                                      accuracy=accuracy))

    # The 1e-10 terms guard against division by zero for absent classes.
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(
        mIoU, mAcc, allAcc))
    for i in range(args.classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(
            i, iou_class[i], accuracy_class[i]))
    logger.info('<<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<<<<<')
コード例 #5
0
def test():
    """Run single-scale inference over the test list and export the results.

    Configuration comes from ``get_parser()``. The network selected by
    ``args.arch`` is built, wrapped in ``DataParallel``, restored from
    ``args.model_path`` and applied to every sample of ``args.test_list``
    (batch size 1). The per-image argmax maps are collected and passed to
    ``dataset.results2img``.

    Raises:
        ValueError: if ``args.arch`` matches none of the handled names.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel normalisation constants, multiplied by value_scale.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(split='test',
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    if args.arch == 'psp':
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor)
    elif args.arch == 'nonlocal':
        model = Nonlocal(layers=args.layers,
                         classes=args.classes,
                         zoom_factor=args.zoom_factor)
    elif args.arch == 'danet':
        model = DANet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif args.arch == 'sanet':
        model = SANet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif args.arch == 'fanet':
        model = FANet(layers=args.layers, classes=args.classes)
    elif args.arch == 'bise_v1':
        model = BiseNet(layers=args.layers,
                        classes=args.classes,
                        with_sp=args.with_sp)
    elif args.arch == 'dct':
        model = DCTNet(layers=args.layers,
                       classes=args.classes,
                       use_dct=True,
                       vec_dim=300)
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # on the first use of `model` below.
        raise ValueError("=> unsupported architecture '{}'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'])
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    logger.info('>>>>>>>>>>>>>>>>> Start Testing >>>>>>>>>>>>>>>>>')
    data_time = AverageMeter()
    batch_time = AverageMeter()

    model.eval()
    end = time.time()
    results = []

    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            data_time.update(time.time() - end)
            # The loader already yields a (1, C, H, W) batch; only the dtype
            # and device need adjusting before the forward pass.
            images = images.float().cuda()
            if args.teacher_model_path is not None and args.arch == 'sanet':
                # In this configuration the model returns three values; only
                # the first one is used for prediction.
                output, _, _ = model(images)
            else:
                output = model(images)
            output = F.softmax(output, dim=1)
            # Class scores of the single image, argmax over the class axis.
            scores = output[0].detach().cpu().numpy()
            results.append(np.argmax(scores, axis=0))
            batch_time.update(time.time() - end)
            end = time.time()
            if ((i + 1) % 10 == 0) or (i + 1 == len(test_loader)):
                logger.info(
                    'Test: [{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}).'.
                    format(i + 1,
                           len(test_loader),
                           data_time=data_time,
                           batch_time=batch_time))

    logger.info('<<<<<<<<<<<<<<<<<< End Testing <<<<<<<<<<<<<<<<<<<<<')

    logger.info('Convert to Label ID')
    dataset.results2img(results=results,
                        data_root=args.data_root,
                        data_list=args.test_list,
                        save_dir='./val_result',
                        to_label_id=True)
    logger.info('Convert to Label ID Finished')
コード例 #6
0
import torch.utils.data
from torch.utils.data import DataLoader
from torchvision import transforms

from utils import dataset, transform

# Sanity-check script: load one training batch through the normalising
# transform and print the mean / std of every sample in it.
data_dir = './dataset'
bs = 8

# Four per-channel statistics handed to Normalize.
mean = [58.6573, 65.9755, 56.4990, 100.8296]
std = [60.2980, 59.0457, 58.0989, 47.1224]

train_transform = transform.Compose([
    transform.ToTensor(),
    transform.Normalize(mean=mean, std=std),
])

train_data = dataset.SemData(
    split='train', data_root=data_dir, transform=train_transform)

train_loader = DataLoader(
    dataset=train_data,
    batch_size=bs,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)

# Each batch unpacks into four items; only the first is inspected here.
data, label, _, _ = next(iter(train_loader))
for sample in data:
    print(sample.mean(), sample.std())


# https://towardsdatascience.com/how-to-calculate-the-mean-and-standard-deviation-normalizing-datasets-in-pytorch-704bd7d05f4c
コード例 #7
0
def main_worker(local_rank, ngpus_per_node, argss):
    """Per-process training entry point.

    Initialises the process group, builds the student model (and, when
    ``args.teacher_model_path`` is set, a PSPNet teacher), the optimizer,
    losses and data loaders, optionally restores weights / a checkpoint, then
    runs the epoch loop with TensorBoard logging and periodic checkpointing.

    Args:
        local_rank: device index used for ``torch.cuda.set_device`` and for
            placing the losses.
        ngpus_per_node: divisor applied to batch sizes and worker count in
            distributed mode.
        argss: parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: if ``args.arch`` is neither 'psp' nor 'bise_v1'.
    """
    global args
    args = argss

    dist.init_process_group(backend=args.dist_backend)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers, classes=args.classes, zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # Every worker runs this; exist_ok avoids the check-then-create race.
        os.makedirs(args.save_path, exist_ok=True)
    if args.arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [model.ffm, model.conv_out, model.conv_out16, model.conv_out32]
    else:
        # Fail early instead of hitting an UnboundLocalError on `modules_ori`.
        raise ValueError("=> unsupported architecture '{}'".format(args.arch))

    # modules_ori train at base_lr, modules_new at ten times base_lr.
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))
    # NOTE(review): 5 equals len(modules_ori) only for 'psp' ('bise_v1' has 2
    # such groups) - confirm how train() consumes index_split.
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)
    if args.distributed:
        torch.cuda.set_device(local_rank)
        # Configured sizes are divided by ngpus_per_node for this process.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model.cuda(), device_ids=[local_rank])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(args.teacher_model_path, map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(args.teacher_model_path))

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                # Checkpoints written below store the epoch under 'epoch'.
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # Per-channel normalisation constants, multiplied by value_scale.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])

    train_data = dataset.SemData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    # DataLoader shuffling is enabled only when no sampler is supplied.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Use .set_epoch() method to reshuffle the dataset partition at every iteration
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(local_rank, train_loader, model, teacher_model, criterion, kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        is_best = False
        # Stays None when evaluation is disabled, so the checkpoint log below
        # never references an unbound name.
        mIoU_val = None
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                    logger.info('==>The best val mIoU: %.3f' % (best_mIoU_val))

        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val
                },
                is_best,
                args.save_path
            )
            if is_best:
                logger.info('Saving checkpoint to:' + args.save_path + '/best.pth with mIoU: ' + str(best_mIoU_val))
            else:
                logger.info('Saving checkpoint to:' + args.save_path + '/last.pth with mIoU: ' + str(mIoU_val))

    if main_process():
        writer.close()  # it must close the writer, otherwise it will appear the EOFError!
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))
コード例 #8
0
ファイル: test.py プロジェクト: Jun-jieChen/Real-time-Semseg
def main():
    """Entry point of the test script: configure, predict and score.

    Parses and checks the arguments, builds the test loader over the
    ``[index_start, index_end)`` slice of the data list, and - unless
    ``args.has_prediction`` is set - builds the network chosen by
    ``args.arch``, loads ``args.model_path`` and calls ``test``. When
    ``args.split`` is not 'test', ``cal_acc`` is run on the gray folder.

    Raises:
        ValueError: if predictions are requested and ``args.arch`` matches
            none of the handled names.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel normalisation constants, multiplied by value_scale.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    test_h = int(args.base_h * args.scale + 1)
    test_w = int(args.base_w * args.scale + 1)

    if args.teacher_model_path:
        # Distillation runs use a sub-directory named after alpha / temperature.
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        kd_save = kd_path + '/val/ss'
        args.save_folder = os.path.join(args.save_folder, kd_save)
        args.save_path = os.path.join(args.save_path, kd_path)
        args.model_path = os.path.join(args.save_path, args.model_path)

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split='test', data_root=args.data_root, data_list=args.test_list, transform=test_transform)
    index_start = args.index_start
    # index_step == 0 means "process everything from index_start onwards".
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Read inside a context manager so the file handle is closed promptly.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        elif args.arch == 'nonlocal':
            from model.nonlocal_net import Nonlocal
            model = Nonlocal(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        elif args.arch == 'sanet':
            from model.sanet import SANet
            model = SANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        elif args.arch == 'bise_v1':
            from model.bisenet_v1 import BiseNet
            model = BiseNet(layers=args.layers, classes=args.classes, with_sp=args.with_sp)
        elif args.arch == 'fanet':
            from model.fanet import FANet
            model = FANet(layers=args.layers, classes=args.classes)
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # on the first use of `model` below.
            raise ValueError("=> unsupported architecture '{}'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        test(test_loader, test_data.data_list, model, args.classes, mean, std, args.base_size, test_h, test_w, args.scales, gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
コード例 #9
0
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([train_h, train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    val_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])

    train_data = dataset.SemData(split='train',
                                 data_root=data_dir,
                                 transform=train_transform)
    val_data = dataset.SemData(split='val',
                               data_root=data_dir,
                               transform=val_transform)
    test_data = dataset.SemData(split='test',
                                data_root=data_dir,
                                transform=val_transform)

    train_loader = DataLoader(dataset=train_data,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)
    val_loader = DataLoader(dataset=val_data,