Beispiel #1
0
def _parse_instance_line(line):
    """Split one line of the instance list into its image and ground-truth paths.

    The line is whitespace-separated. If its second field contains the
    substring ``'ignore'``, the ground-truth path is taken from the third
    field and the flag is 1; otherwise the ground-truth path is the second
    field and the flag is 0.

    Returns:
        tuple: ``(image_path, gt_path, flag)``.
    """
    # str.split() with no argument also discards leading/trailing whitespace.
    fields = line.split()
    if 'ignore' in fields[1]:
        return fields[0], fields[2], 1
    return fields[0], fields[1], 0


def main():
    """Evaluate a PSPNet checkpoint on ``args.val_list1``.

    Parses command-line arguments and builds the validation loader and the
    model. It loads the checkpoint at ``args.model_path``, raising
    ``RuntimeError`` if the file does not exist. It then copies the first 101
    image / ground-truth pairs of the instance list into ``result/``. Finally
    it calls ``validate`` with the per-entry ``flag`` list: 1 where the list
    line marks the entry as 'ignore', otherwise 0.

    Raises:
        RuntimeError: if the checkpoint is missing or an image listed in the
            instance list cannot be read by OpenCV.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    # Crop sizes must be of the form 8k + 1.
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std (the usual ImageNet values), rescaled from the
    # [0, 1] range to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    val_transform = transforms.Compose([
        transforms.Resize((args.crop_h, args.crop_w)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    val_data1 = datasets.SegData(split=args.split,
                                 data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=False,
                   use_aux=False,
                   pretrained=False,
                   syncbn=False).cuda()

    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.enabled = True
    cudnn.benchmark = True
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # strict=False: keys that do not match the model are skipped without error.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    cv2.setNumThreads(0)

    instance_list = '/mnt/lustre/share/dingmingyu/cityscapes/instance_list_new.txt'
    # cv2.imwrite does not create missing directories.
    os.makedirs('result', exist_ok=True)
    flag = []
    # Stream the list inside `with` so the file handle is always closed.
    with open(instance_list) as list_file:
        for i, line in enumerate(list_file):
            if i > 100:  # only entries 0..100 are dumped
                break
            img_path, gt_path, ignore = _parse_instance_line(line)
            img = cv2.imread(img_path)
            gt = cv2.imread(gt_path)
            # cv2.imread returns None instead of raising on unreadable files.
            if img is None or gt is None:
                raise RuntimeError("=> cannot read image '{}'".format(
                    img_path if img is None else gt_path))
            flag.append(ignore)

            cv2.imwrite('result/result_%d_gt.png' % i, gt)
            cv2.imwrite('result/result_%d_ori.png' % i, img)

    validate(val_loader1, val_data1.data_list, model, args.classes, mean, std,
             args.base_size1, args.crop_h, args.crop_w, flag)
Beispiel #2
0
def main():
    """Distributed training entry point for PSPNet plus the PPM head.

    Parses arguments and forces the 'spawn' multiprocessing start method. It
    initialises the process group via ``dist_init`` and builds the PSPNet and
    PPM models (both wrapped in ``DistModule``). The SGD optimizer uses
    ``base_lr`` for the backbone layers and ``10 * base_lr`` for the PPM
    ``cls_trans`` / ``cls_quat`` heads. Training runs over epochs
    ``args.start_epoch .. args.epochs`` inclusive. TensorBoard logging and
    checkpointing happen on rank 0 only. Validation runs every epoch when
    ``args.evaluate`` is set.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes,
                   zoom_factor=args.zoom_factor, syncbn=args.syncbn,
                   group_size=args.bn_group, group=args.bn_group_comm).cuda()
    if rank == 0:
        logger.info(model)
    model_ppm = PPM().cuda()
    # Backbone parameter groups use base_lr; the newly introduced PPM heads use 10x.
    optimizer = torch.optim.SGD(
        [{'params': model.layer0.parameters()},
         {'params': model.layer1.parameters()},
         {'params': model.layer2.parameters()},
         {'params': model.layer3.parameters()},
         {'params': model.layer4_ICR.parameters()},
         {'params': model.layer4_PFR.parameters()},
         {'params': model.layer4_PRP.parameters()},
         {'params': model_ppm.cls_trans.parameters(), 'lr': args.base_lr * 10},
         {'params': model_ppm.cls_quat.parameters(), 'lr': args.base_lr * 10}
        ],
        lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    model = DistModule(model)
    model_ppm = DistModule(model_ppm)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.L1Loss().cuda()

    if args.weight:
        def map_func(storage, location):
            # Move every deserialised storage onto the current CUDA device.
            return storage.cuda()
        if os.path.isfile(args.weight):
            if rank == 0:
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            if rank == 0:
                logger.info("=> loaded weight '{}'".format(args.weight))
        elif rank == 0:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, model_ppm, optimizer)

    # Per-channel mean/std (the usual ImageNet values) scaled to 0-255.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transforms.ColorJitter([0.4, 0.4, 0.4]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    train_data = datasets.SegData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Resize(size=(256, 256)),
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # torch's DistributedSampler derives its shuffle order from the epoch
        # set here. Without this call every epoch replays the same order.
        # Skipped for samplers that do not provide set_epoch.
        if hasattr(train_sampler, 'set_epoch'):
            train_sampler.set_epoch(epoch)
        t_loss_train, r_loss_train = train(train_loader, model, model_ppm, criterion, optimizer, epoch, args.zoom_factor, args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('t_loss_train', t_loss_train, epoch)
            writer.add_scalar('r_loss_train', r_loss_train, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            filename = os.path.join(args.save_path, 'train_epoch_' + str(epoch) + '.pth')
            logger.info('Saving checkpoint to: ' + filename)
            filename_ppm = os.path.join(args.save_path, 'train_epoch_' + str(epoch) + '_ppm.pth')
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            torch.save({'epoch': epoch, 'state_dict': model_ppm.state_dict(), 'optimizer': optimizer.state_dict()}, filename_ppm)
        if args.evaluate:
            # validate() runs on every rank; only rank 0 writes the scalars,
            # matching the training scalars above.
            t_loss_val, r_loss_val = validate(val_loader, model, model_ppm, criterion)
            if rank == 0:
                writer.add_scalar('t_loss_val', t_loss_val, epoch)
                writer.add_scalar('r_loss_val', r_loss_val, epoch)
    writer.close()
Beispiel #3
0
def _companion_checkpoint_path(model_path, tag):
    """Return the path of a companion checkpoint stored next to *model_path*.

    The LAST occurrence of ``'.pth'`` in *model_path* is replaced by
    ``'_<tag>.pth'``. For example, ``'ckpt/train_epoch_10.pth'`` with tag
    ``'pfr'`` becomes ``'ckpt/train_epoch_10_pfr.pth'``. Only the last
    occurrence is rewritten, so directory names that happen to contain
    ``'.pth'`` are left intact.

    Raises:
        RuntimeError: if *model_path* contains no ``'.pth'``. Otherwise the
            derived path would silently equal *model_path* itself.
    """
    head, sep, tail = model_path.rpartition('.pth')
    if not sep:
        raise RuntimeError("=> cannot derive '{}' checkpoint path from '{}'".format(
            tag, model_path))
    return '{}_{}.pth{}'.format(head, tag, tail)


def main():
    """Evaluate a PSPNet checkpoint together with its PFR and PRP companions.

    Builds the validation loader for ``args.val_list1`` (resize to 256x256,
    centre crop, normalise). It then builds the PSPNet, PFR and PRP models
    (each wrapped in ``DataParallel``) and loads ``args.model_path``. The
    companion checkpoints come from ``<stem>_pfr.pth`` and ``<stem>_prp.pth``.
    All loads use ``strict=False``. Finally it calls ``validate``.

    Raises:
        RuntimeError: if any of the three checkpoint files is missing.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    logger.info("=> creating model ...")
    # Per-channel mean/std (the usual ImageNet values) scaled to 0-255.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    val_transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='center',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    val_data1 = datasets.SegData(split=args.split,
                                 data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    model_pfr = torch.nn.DataParallel(PFR().cuda())
    model_prp = torch.nn.DataParallel(PRP().cuda())

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=True,
                   pretrained=False,
                   syncbn=False).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.enabled = True
    cudnn.benchmark = True
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # strict=False: keys that do not match the model are skipped without error.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # Companion checkpoints are checked the same way as the main one.
    for tag, companion in (('pfr', model_pfr), ('prp', model_prp)):
        path = _companion_checkpoint_path(args.model_path, tag)
        if not os.path.isfile(path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(path))
        logger.info("=> loading checkpoint '{}'".format(path))
        companion.load_state_dict(torch.load(path)['state_dict'], strict=False)

    cv2.setNumThreads(0)

    validate(val_loader1, val_data1.data_list, model, model_pfr, model_prp)