예제 #1
0
def main():
    """Build the PSPNet + PPM evaluation pipeline and run validation once.

    Parses CLI arguments, constructs the center-crop validation loader,
    restores both checkpoints (main model and PPM head) and hands
    everything to ``validate``.  Mutates the module-level ``args`` and
    ``logger`` globals used elsewhere in the file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    logger.info("=> creating model ...")

    # ImageNet normalization statistics, rescaled to the [0, 255] pixel range.
    scale = 255
    norm_mean = [v * scale for v in [0.485, 0.456, 0.406]]
    norm_std = [v * scale for v in [0.229, 0.224, 0.225]]

    eval_transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='center',
                        padding=norm_mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    val_data1 = datasets.SegData(split=args.split,
                                 data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=eval_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    model_ppm = torch.nn.DataParallel(PPM().cuda())

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=True,
                   pretrained=False,
                   syncbn=False).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.enabled = True
    cudnn.benchmark = True

    # Fail fast when the checkpoint is missing; the PPM weights are expected
    # next to it under the '<name>_ppm.pth' naming convention.
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    checkpoint_ppm = torch.load(args.model_path.replace('.pth', '_ppm.pth'))
    model_ppm.load_state_dict(checkpoint_ppm['state_dict'], strict=False)

    # Disable OpenCV's internal threading (commonly done to avoid clashing
    # with DataLoader worker processes).
    cv2.setNumThreads(0)

    validate(val_loader1, val_data1.data_list, model, model_ppm)
예제 #2
0
        dataset = CityscapesSemanticSegmentationDataset(
            args.cityscapes_img_dir, None, args.split)
    elif args.model == 'ADE20K':
        # Reference configuration for the PSPNet-101 ADE20K release
        # (150 classes, 473x473 crops).
        n_class = 150
        n_blocks = [3, 4, 6, 3]
        feat_size = 60
        mid_stride = False
        param_fn = 'weights/pspnet101_ADE20K_473_reference.chainer'
        base_size = 512
        crop_size = 473

    # Apply `preprocess` lazily to every item of the chosen dataset.
    dataset = TransformDataset(dataset, preprocess)
    print('dataset:', len(dataset))

    # Evaluation mode for all Chainer links.
    chainer.config.train = False
    model = PSPNet(n_class, n_blocks, feat_size, mid_stride=mid_stride)
    serializers.load_npz(param_fn, model)
    if args.gpu >= 0:
        # NOTE(review): get_device_from_id(...) is called without .use(),
        # so the current CUDA device is never actually switched; model.to_gpu
        # still places the weights on args.gpu — verify this is intended.
        chainer.cuda.get_device_from_id(args.gpu)
        model.to_gpu(args.gpu)

    # Predict every image in the inclusive index range [start_i, end_i];
    # output filename mirrors the input image's basename.
    for i in tqdm(range(args.start_i, args.end_i + 1)):
        img = dataset[i]
        out_fn = os.path.join(
            args.out_dir, os.path.basename(dataset._dataset.img_fns[i]))
        pred = inference(
            model, n_class, base_size, crop_size, img, args.scales)
        assert pred.ndim == 2  # expect a single 2-D label map

        if args.model == 'Cityscapes':
            if args.color_out_dir is not None:
예제 #3
0
import argparse
import os

import numpy as np
import h5py
from scipy import misc

from pspnet import PSPNet
import utils_run as utils

# Command line: which project to run prediction for, which GPU/device id,
# and whether to write results to a local scratch directory instead of the
# configured output location.
parser = argparse.ArgumentParser()
parser.add_argument("-p", required=True, help="Project name")
parser.add_argument('--id', default=0, type=int)
parser.add_argument('--local', action='store_true', default=False)
args = parser.parse_args()

project = args.p
pspnet = PSPNet(DEVICE=args.id)
pspnet.print_network_architecture()

CONFIG = utils.get_config(project)
im_list = utils.open_im_list(project)

root_images = CONFIG["images"]
root_result = CONFIG["pspnet_prediction"]
if args.local:
    root_result = "pspnet_prediction_tmp/"

# Output sub-directories for the different prediction artifacts.
# `os` was used below but never imported in the original — fixed above.
root_mask = os.path.join(root_result, 'category_mask')
root_prob = os.path.join(root_result, 'prob_mask')
root_maxprob = os.path.join(root_result, 'max_prob')
root_allprob = os.path.join(root_result, 'all_prob')
예제 #4
0
def _head_lr10_optimizer(model):
    """Build the SGD optimizer for ``model`` according to ``args.net_type``.

    Backbone stages (layer0..layer4, plus layer4_p for net_type 1-3) train
    at the base learning rate; the newly-introduced head modules train at
    10x the base LR.  Group order matches the original hand-written lists.
    """
    backbone = [model.layer0, model.layer1, model.layer2, model.layer3,
                model.layer4]
    if args.net_type == 0:
        heads = [model.ppm, model.conv6, model.conv1_1x1, model.cls,
                 model.aux]
    else:
        backbone.append(model.layer4_p)
        heads = [model.ppm, model.ppm_p, model.cls, model.cls_p]
        if args.net_type == 2:
            heads += [model.att, model.att_p]
        elif args.net_type == 3:
            heads.append(model.att)
        heads.append(model.aux)
    param_groups = [{'params': m.parameters()} for m in backbone]
    param_groups += [{'params': m.parameters(), 'lr': args.base_lr * 10}
                     for m in heads]
    return torch.optim.SGD(param_groups,
                           lr=args.base_lr,
                           momentum=args.momentum,
                           weight_decay=args.weight_decay)


def main():
    """Distributed PSPNet training entry point.

    Parses arguments, initializes the process group, builds the model,
    optimizer, data loaders and (optionally) the validation loader, then
    runs the epoch loop with rank-0 checkpointing and TensorBoard logging.
    Mutates the module-level ``args``, ``logger`` and ``writer`` globals.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    # Use 'spawn' so CUDA can be initialized safely in worker processes.
    import multiprocessing as mp
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)
    # Sanity-check CLI arguments before building anything expensive.
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    # Crop sizes must satisfy size = 8k + 1 (stride alignment of the net).
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.net_type in [0, 1, 2, 3]

    # Split ranks into BN-synchronization groups when requested.
    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank,
                                                world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.net_type == 0:
        from pspnet import PSPNet
        model = PSPNet(backbone=args.backbone,
                       layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       syncbn=args.syncbn,
                       group_size=args.bn_group,
                       group=args.bn_group_comm).cuda()
    elif args.net_type in [1, 2, 3]:
        from pspnet_div4 import PSPNet
        model = PSPNet(backbone=args.backbone,
                       layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       syncbn=args.syncbn,
                       group_size=args.bn_group,
                       group=args.bn_group_comm,
                       net_type=args.net_type).cuda()
    # NOTE(review): unlike the messages above, this log is not guarded by
    # rank == 0, so every rank prints the full model — confirm intended.
    logger.info(model)
    # Replaces four nearly identical hand-written optimizer blocks; the
    # per-group ordering and learning rates are unchanged.
    optimizer = _head_lr10_optimizer(model)
    model = DistModule(model)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:

        def map_func(storage, location):
            # Deserialize checkpoint tensors directly onto the GPU.
            return storage.cuda()

        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, optimizer)

    # ImageNet normalization statistics rescaled to the [0, 255] range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    train_data = datasets.SegData(split='train',
                                  data_root=args.data_root,
                                  data_list=args.train_list,
                                  transform=train_transform)
    # The sampler shards and shuffles per rank, so the loader itself must
    # not shuffle.
    train_sampler = DistributedSampler(train_data)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w],
                            crop_type='center',
                            padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        val_data = datasets.SegData(split='val',
                                    data_root=args.data_root,
                                    data_list=args.val_list,
                                    transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch, args.zoom_factor,
            args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('loss_train', loss_train, epoch)
            writer.add_scalar('mIoU_train', mIoU_train, epoch)
            writer.add_scalar('mAcc_train', mAcc_train, epoch)
            writer.add_scalar('allAcc_train', allAcc_train, epoch)

        # Only rank 0 writes checkpoints, every save_step epochs.
        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion, args.classes, args.zoom_factor)
            # NOTE(review): unlike the train scalars above, these are
            # written from every rank — consider guarding with rank == 0.
            writer.add_scalar('loss_val', loss_val, epoch)
            writer.add_scalar('mIoU_val', mIoU_val, epoch)
            writer.add_scalar('mAcc_val', mAcc_val, epoch)
            writer.add_scalar('allAcc_val', allAcc_val, epoch)
    conv10 = Conv2D(3, 1, activation = 'softmax')(conv9)

    model = Model(input = inputs, output = conv10)

    model.compile(optimizer = Adam(lr = 0.000009), loss = 'categorical_crossentropy', metrics = ['accuracy'])
    
    model.summary()
    
    filelist_modelweights = sorted(glob.glob('*.h5'), key=numericalSort)
    
    if 'model_nocropping.h5' in filelist_modelweights:
        model.load_weights('model_nocropping.h5')
    return model
'''

# Build the (Keras) PSPNet model used below.
model = PSPNet()

# Data Augmentation

# Keyword arguments intended for Keras' ImageDataGenerator; kept in one
# dict so the (currently commented-out) image and mask generators below
# can share an identical augmentation configuration.
datagen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')

#x_datagen = ImageDataGenerator(**datagen_args)
#y_datagen = ImageDataGenerator(**datagen_args)

#seed = 1
예제 #6
0
    parser.add_argument("--outputs",
                        type=str,
                        default="outputs/",
                        help="保存结果的文件夹的路径")
    args = parser.parse_args()
    print(args)

    # Unpack CLI arguments into constants used by the export below.
    CLASS_NUM = args.class_num
    WEIGHTS = args.weights
    COLORS = args.colors
    SAMPLES = args.samples
    OUTPUTS = args.outputs

    # Prefer the first GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = PSPNet(num_classes=CLASS_NUM,
                   downsample_factor=16,
                   pretrained=False,
                   aux_branch=False).to(device=device)
    print('model structure is: ')
    print(model)

    # Load trained weights onto the chosen device and switch to eval mode.
    model.load_state_dict(torch.load(WEIGHTS, map_location=device))
    model.eval()

    # Convert the PyTorch model to ONNX.
    print('=================================')
    # The dummy input fixes the traced graph's input shape (1, 3, 473, 473).
    dummy_input = torch.randn(1, 3, 473, 473).to(device)
    torch.onnx.export(model,
                      dummy_input,
                      'pspnet.onnx',
                      dynamic_axes={
                          'image': {
예제 #7
0
                          args.batch_size_train,
                          shuffle=True,
                          drop_last=True,
                          num_workers=args.num_workers,
                          pin_memory=not args.no_cuda)
# Validation loader: keep every sample, in order, for stable metrics.
loader_valid = DataLoader(dataset_valid,
                          args.batch_size_valid,
                          shuffle=False,
                          drop_last=False,
                          num_workers=args.num_workers,
                          pin_memory=not args.no_cuda)

# Factory table mapping a backbone name to a zero-argument constructor;
# the lambdas defer building the network until a backend is selected.
models = {
    'squeezenet':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=512,
                   deep_features_size=256,
                   backend='squeezenet'),
    'densenet':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=1024,
                   deep_features_size=512,
                   backend='densenet'),
    'resnet18':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=512,
                   deep_features_size=256,
                   backend='resnet18'),
    'resnet34':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=512,
                   deep_features_size=256,
예제 #8
0
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable
from torch.utils.data import DataLoader

from tqdm import tqdm
import click
import numpy as np

from pspnet import PSPNet

# Lazy factories for the deep-ResNet PSPNet variants.  All three flavours
# share the same pyramid sizes and feature widths, so the table is built
# from the backend names alone.  Each value is a zero-argument callable,
# meaning no network is constructed until one entry is actually invoked;
# `backend=backend` binds the loop variable eagerly to avoid the classic
# late-binding closure pitfall.
models = {
    backend: (lambda backend=backend: PSPNet(sizes=(1, 2, 3, 6),
                                             psp_size=2048,
                                             deep_features_size=1024,
                                             backend=backend))
    for backend in ('resnet50', 'resnet101', 'resnet152')
}


def build_network(snapshot, backend):
예제 #9
0
def main():
    """Fine-tune a Caffe-pretrained PSPNet-101 for binary KITTI segmentation.

    The pretrained backbone is frozen; only the freshly re-initialized head
    (cbr_final, dropout, classification) remains trainable.  Supports
    resuming from a snapshot whose filename encodes epoch/iter/loss/accu.
    Reads the module-level ``args`` dict, ``ckpt_path`` and ``exp_name``.
    """
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    # Freeze every pretrained parameter ...
    for param in net.parameters():
        param.requires_grad = False
    # ... then replace the head; the new modules are created with
    # requires_grad=True, so they are the only trainable parameters.
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Report total vs. trainable parameter counts.
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters()
                                 if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    if len(args['snapshot']) == 0:
        # Fresh run: initialize best-record bookkeeping.
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        # Resume: the snapshot filename carries the numeric fields at the
        # odd underscore-separated positions (epoch, iter, val_loss, accu).
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        # NOTE(review): curr_epoch is computed but never used below — the
        # resumed epoch is not propagated to train(); verify intended.
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7])
        }
    net.cuda(args['gpu']).train()

    # Caffe-style preprocessing statistics: per-channel means in the
    # [0, 255] range with unit std (mean subtraction only).
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    # Image pipeline: FlipChannels (presumably RGB->BGR to match the Caffe
    # means above — verify), scale to [0, 255], subtract the means.
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=8,
                            shuffle=False)

    # Binary segmentation loss; pos_weight slightly up-weights positives.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    # Group 0: biases at 2x LR with no weight decay.
    # Group 1: everything else at the base LR with weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the LR schedule after restoring optimizer state
        # (group 0 = biases, group 1 = the rest, as constructed above).
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the full run configuration.  Using a context manager fixes the
    # original code's leaked file handle (open(...).write(...) never closed).
    run_log_path = os.path.join(ckpt_path, exp_name,
                                str(datetime.datetime.now()) + '.txt')
    with open(run_log_path, 'w') as run_log:
        run_log.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)