def __init__(self, config_file=CONFIG_FILE):
        """Load the config, build the segmentation network and restore its weights.

        Args:
            config_file: Config file path passed to
                ``config.load_cfg_from_cfg_file``.

        Raises:
            ValueError: If the configured ``arch`` is neither ``'psp'`` nor
                ``'psa'``.
            RuntimeError: If no checkpoint file exists at ``model_path``.
        """
        # Load parameters.
        self.args_ = config.load_cfg_from_cfg_file(config_file)
        self.logger_ = get_logger()
        # Restrict the GPUs this process may see to the configured test GPUs.
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in self.args_.test_gpu)
        # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        self.mean_ = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        self.std_ = [item * value_scale for item in std]

        # Build the model; the project modules are imported lazily per arch.
        if self.args_.arch == 'psp':
            from model.pspnet import PSPNet
            self.model_ = PSPNet(layers=self.args_.layers,
                                 classes=self.args_.classes,
                                 zoom_factor=self.args_.zoom_factor,
                                 pretrained=False)
        elif self.args_.arch == 'psa':
            from model.psanet import PSANet
            self.model_ = PSANet(
                layers=self.args_.layers,
                classes=self.args_.classes,
                zoom_factor=self.args_.zoom_factor,
                compact=self.args_.compact,
                shrink_factor=self.args_.shrink_factor,
                mask_h=self.args_.mask_h,
                mask_w=self.args_.mask_w,
                normalization_factor=self.args_.normalization_factor,
                psa_softmax=self.args_.psa_softmax,
                pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'"
                             .format(self.args_.arch))
        self.model_ = torch.nn.DataParallel(self.model_).cuda()
        cudnn.benchmark = True

        if not os.path.isfile(self.args_.model_path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args_.model_path))
        # Log through the stored logger. The original assigned the result of
        # ``.info(...)`` (None) to self.logger_, which destroyed the logger.
        self.logger_.info(
            "=> loading checkpoint '{}'".format(self.args_.model_path))
        checkpoint = torch.load(self.args_.model_path)
        # strict=False: keys that do not match the model are ignored.
        self.model_.load_state_dict(checkpoint['state_dict'], strict=False)
        self.logger_.info(
            "=> loaded checkpoint '{}'".format(self.args_.model_path))
# Example 2 (score: 0)
# File: test.py, project: whrws/semseg
def main():
    """Evaluate a segmentation model on ``args.test_list`` and score it.

    Unless ``args.has_prediction`` is set, builds the network, restores
    ``args.model_path`` and calls ``test`` to write gray/color predictions
    under ``args.save_folder``. For any split other than ``'test'``, it then
    computes accuracy from the gray predictions with ``cal_acc``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If no checkpoint exists at ``args.model_path``.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Only ToTensor here; mean/std are handed to test() separately.
    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split, data_root=args.data_root, data_list=args.test_list, transform=test_transform)
    # Evaluate only the slice [index_start, index_start + index_step);
    # index_step == 0 means "through the end of the list".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Close the class-names file deterministically (it was previously leaked).
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                           shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            # strict=False: keys that do not match the model are ignored.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        test(test_loader, test_data.data_list, model, args.classes, mean, std, args.base_size, args.test_h, args.test_w, args.scales, gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
def main():
    """Run single-image inference over an image list, then write a label list.

    Each line of ``args.image`` is split on whitespace. Its first field is
    joined onto the portrait data root and passed to ``test``. Afterwards,
    every file in the matching label directory is written, as an absolute
    path, to ``training.txt`` (when ``args.image`` is named ``training.txt``)
    or to ``validation.txt`` (otherwise).

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If no checkpoint exists at ``args.model_path``.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Hard-coded location of the portrait dataset (kept in one place).
    portrait_root = '/scratch/mw3706/dim/Deep_Image_Matting_Reproduce/pspnet/data/portrait/'
    label_root = portrait_root + 'label/'

    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
    else:
        raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    # cudnn autotuning is deliberately left off in this script.
    cudnn.benchmark = False
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        # strict=False: keys that do not match the model are ignored.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))

    # Switch to inference mode once instead of once per image.
    model.eval()
    with open(args.image) as list_file:
        list_lines = list_file.read().splitlines()
    for line in list_lines:
        fields = line.split()
        if not fields:
            # Skip blank lines instead of failing with IndexError.
            continue
        image_path = os.path.join(portrait_root, fields[0])
        test(model, image_path, args.classes, mean, std, args.base_size, args.test_h, args.test_w, args.scales, colors)

    # Pick the label directory / list file from the name of the input list.
    if args.image.split('/')[-1] == 'training.txt':
        label_dir, list_name = label_root + 'train_label/', 'training.txt'
    else:
        label_dir, list_name = label_root + 'val_label/', 'validation.txt'
    # List the directory before truncating the output file.
    labels = os.listdir(label_dir)
    with open(label_root + list_name, 'w') as out:
        out.writelines(label_dir + label + '\n' for label in labels)
# Example 4 (score: 0)
def main():
    """Run inference on every ``scene*/color/*00.jpg`` image under ``args.image``.

    Builds the configured network, restores ``args.model_path`` and calls
    ``test`` once per matching image, in sorted path order.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If no checkpoint exists at ``args.model_path``.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       pretrained=False)
    else:
        raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'"
                         .format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        # strict=False: keys that do not match the model are ignored.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    # Switch to inference mode once, not once per image.
    model.eval()
    # glob returns paths in arbitrary order; sort for reproducible runs.
    paths = sorted(glob.glob(args.image + '/scene*/color/*00.jpg'))
    for path in paths:
        test(model, path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)
# Example 5 (score: 0)
def main_worker(gpu, ngpus_per_node, argss):
    """Train a PSPNet/PSANet segmentation model in one worker process.

    Args:
        gpu: Local GPU index of this worker. It is used to derive the global
            rank and, in distributed mode, for ``torch.cuda.set_device``.
        ngpus_per_node: GPUs per node. In distributed mode the batch sizes and
            loader worker counts are divided by this value.
        argss: Parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss

    # Step 1: distributed initialisation.
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Step 2: build the network (adapt to your own model/loss as needed).
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    else:
        raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'"
                         .format(args.arch))

    # Step 3: optimizer. Backbone layers train at base_lr; the newly added
    # heads train at 10x base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Number of leading (backbone) parameter groups in params_list.
    # NOTE(review): presumably read by train() for per-group LR — confirm.
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Step 4: logging (main process only) and parallel wrapping.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        # Split batch sizes and loader workers across this node's GPUs.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    # Step 5: load weights.
    # 5.1 Initial weights only (no optimizer state). Tensors are mapped to
    # the current CUDA device, as the resume path below already does.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(
                args.weight, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))
    # 5.2 Resume an interrupted run: model, optimizer and start epoch.
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # Step 6: data loaders.
    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    # Training data (adapt the dataset to your own needs).
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    # Step 7: main training loop.
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Reshuffle the distributed shards differently every epoch.
            train_sampler.set_epoch(epoch)

        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        # Save a checkpoint every save_freq epochs.
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                # Keep only the two most recent checkpoints.
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    # The file can be absent, e.g. after resuming with a
                    # different save_freq. Do not abort training for that.
                    logger.info('Old checkpoint not found, skipping delete: '
                                + deletename)

        # Evaluate after every epoch.
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

    if main_process():
        # Flush pending summaries to disk.
        writer.close()
# Example 6 (score: 0)
def main_worker(gpu, ngpus_per_node, argss):
    """Train a KDNet (``'psp'``) or PSANet (``'psa'``) model in one worker process.

    Args:
        gpu: Local GPU index of this worker. It is used to derive the global
            rank and, in distributed mode, for ``torch.cuda.set_device``.
        ngpus_per_node: GPUs per node. In distributed mode the batch sizes and
            loader worker counts are divided by this value.
        argss: Parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.kdnet import KDNet
        model = KDNet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor,
                      criterion=criterion,
                      temperature=args.temperature,
                      alpha=args.alpha)
        # Only the student network's parameters are given to the optimizer.
        modules_ori = [
            model.student_net.layer0, model.student_net.layer1,
            model.student_net.layer2, model.student_net.layer3,
            model.student_net.layer4
        ]
        modules_new = [
            model.student_net.ppm, model.student_net.cls, model.student_net.aux
        ]
        # NOTE(review): teacher_net is not used below; kept for parity.
        teacher_net = model.teacher_loader
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    else:
        raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'"
                         .format(args.arch))

    # Backbone layers train at base_lr; the new heads at 10x base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Number of leading (backbone) parameter groups in params_list.
    # NOTE(review): presumably read by train() for per-group LR — confirm.
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        # Split batch sizes and loader workers across this node's GPUs.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        # find_unused_parameters=True: DDP tolerates parameters that receive
        # no gradient in a step.
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu], find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model.cuda())

    # Initial weights only (no optimizer state). Tensors are mapped to the
    # current CUDA device, as the resume path below already does.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(
                args.weight, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # Resume an interrupted run: model, optimizer and start epoch.
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Reshuffle the distributed shards differently every epoch.
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                # Keep only the two most recent checkpoints.
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    # The file can be absent, e.g. after resuming with a
                    # different save_freq. Do not abort training for that.
                    logger.info('Old checkpoint not found, skipping delete: '
                                + deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

    if main_process():
        # Flush pending summaries to disk.
        writer.close()
# Example 7 (score: 0)
def main(
        config_name,
        weights_url='https://github.com/deepparrot/semseg/releases/download/0.1/pspnet50-ade20k.pth',
        weights_name='pspnet50-ade20k.pth',
        data_root='./.data/vision/ade20k'):
    """Evaluate a pretrained PSPNet/PSANet on the ADE20K validation list.

    Args:
        config_name: Config file passed to ``config.load_cfg_from_cfg_file``.
        weights_url: URL the checkpoint is downloaded from, on every call.
        weights_name: Local filename the downloaded checkpoint is saved as.
        data_root: Dataset root. It overrides ``args.data_root``, and its
            ``validation.txt`` is used as both the val and the test list.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If the downloaded checkpoint file cannot be found.
    """
    args = config.load_cfg_from_cfg_file(config_name)
    check(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)

    # ImageNet mean/std rescaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Override the configured dataset location with the benchmark layout.
    args.data_root = data_root
    args.val_list = os.path.join(data_root, 'validation.txt')
    args.test_list = args.val_list

    print(args.data_root)

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    # Evaluate only the slice [index_start, index_start + index_step);
    # index_step == 0 means "through the end of the list".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # No class-names file is read here, so cal_acc gets an empty list.
    # NOTE(review): confirm cal_acc does not index into names.
    names = []

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}'; expected 'psp' or 'psa'"
                             .format(args.arch))
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

        local_checkpoint, _ = urllib.request.urlretrieve(
            weights_url, weights_name)

        if os.path.isfile(local_checkpoint):
            checkpoint = torch.load(local_checkpoint)
            # strict=False: keys that do not match the model are ignored.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(local_checkpoint))
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, args.test_h, args.test_w, args.scales,
             gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
# Example 8 (score: 0)
def main():
    """Generate gray-scale segmentation maps for the Cityscapes video image list.

    Configuration comes from ``get_parser()``. Unless ``args.has_prediction``
    is set, the network chosen by ``args.arch`` is built, its weights are
    restored from ``args.ckpt_path`` (after conversion by ``transfer_ckpt``),
    and ``test`` is run on the slice
    ``[args.index_start, args.index_start + args.index_step)`` of
    ``./data/list/cityscapes/val_video_img_sam.lst`` (the whole list when
    ``args.index_step == 0``).

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
        RuntimeError: if ``args.ckpt_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    # check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gen_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # NOTE(review): str.replace rewrites *every* 'ss' substring of the path,
    # not just one directory name -- confirm this is intended.
    gray_folder = os.path.join(args.save_folder.replace('ss', 'video'), 'gray')

    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(
        split='test',
        data_root=args.data_root,
        data_list='./data/list/cityscapes/val_video_img_sam.lst',
        transform=test_transform)
    # Keep only the requested slice of the image list (all of it when
    # index_step == 0).
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_gen,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.origin_pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psp18':
            from model.pspnet_18 import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           flow=False,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        elif args.arch == 'mobile':
            from model.mobile import DenseASPP
            model = DenseASPP(layers=args.layers,
                              classes=args.classes,
                              zoom_factor=args.zoom_factor,
                              flow=False)
        elif args.arch == 'antipsp18':
            from model.antipspnet18 import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           flow=False)
        else:
            # Fail early with a clear message instead of a NameError on `model`.
            raise ValueError("=> unsupported arch '{}'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.ckpt_path):
            logger.info("=> loading checkpoint '{}'".format(args.ckpt_path))
            checkpoint = torch.load(args.ckpt_path)
            student_ckpt = transfer_ckpt(checkpoint)
            # load_state_dict returns (missing_keys, unexpected_keys), in
            # that order.
            missing_keys, unexpected_keys = model.load_state_dict(
                student_ckpt, strict=False)
            print('unexpected keys:', unexpected_keys)
            print('missing keys:', missing_keys)
            logger.info("=> loaded checkpoint '{}'".format(args.ckpt_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.ckpt_path))

        # The crop size is fixed to 1024x2048 here instead of
        # args.test_h / args.test_w.
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, 1024, 2048, args.scales, gray_folder, colors)
Esempio n. 9
0
def main():
    """Run segmentation inference on a data list and optionally score it.

    Configuration comes from ``get_parser()``; data/config/project paths are
    rebased (``data_config_path`` / ``project_path``) before ``check(args)``.
    Unless ``args.has_prediction`` is set, the model selected by ``args.arch``
    is restored from ``args.model_path`` and ``test`` writes gray/color
    predictions under ``args.save_folder``. Accuracy is then computed with
    ``cal_acc`` unless the split is ``'test'`` without ground truth.

    Raises:
        ValueError: if ``args.arch`` is not ``'psp'`` or ``'psa'``.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    if args.test_in_nyu_label_space:
        args.colors_path = 'nyu/nyu_colors.txt'
        args.names_path = 'nyu/nyu_names.txt'

    if args.if_cluster:
        args.data_root = args.data_root_cluster
        args.project_path = args.project_path_cluster
        args.data_config_path = 'data'
    # Make the list/palette paths relative to the data config directory and
    # the output paths relative to the project directory.
    for key in ['train_list', 'val_list', 'test_list', 'colors_path', 'names_path']:
        args[key] = os.path.join(args.data_config_path, args[key])
    for key in ['save_path', 'model_path', 'save_folder']:
        args[key] = os.path.join(args.project_path, args[key])

    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    transform_list_test = []
    if args.resize:
        transform_list_test.append(transform.Resize((args.resize_h_test, args.resize_w_test)))
    transform_list_test += [
        transform.Crop([args.test_h, args.test_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    test_transform = transform.Compose(transform_list_test)
    test_data = dataset.SemData(split=args.split, data_root=args.data_root, data_list=args.test_list, transform=test_transform, is_master=True, args=args)
    # Keep only the requested slice of the list (all of it when index_step == 0).
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Read the class names with a context manager so the file is closed.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    args.read_image = test_data.read_image

    # Prediction/target path lists only exist when inference runs below; when
    # predictions are already on disk, cal_acc falls back to its own defaults
    # instead of hitting an unbound local.
    acc_kwargs = {}
    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                           shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
        else:
            raise ValueError("=> unsupported arch '{}'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        pred_path_list, target_path_list = test(test_loader, test_data.data_list, model, args.classes, mean, std, args.base_size, args.test_h, args.test_w, args.scales, gray_folder, color_folder, colors)
        acc_kwargs = dict(pred_path_list=pred_path_list, target_path_list=target_path_list)
    # Score every non-test split, and the test split only when it has labels.
    if args.split != 'test' or args.test_has_gt:
        cal_acc(test_data.data_list, gray_folder, args.classes, names, **acc_kwargs)
class PSPNetSematicSegmentation(object):
    """Single-image semantic segmentation wrapper around PSPNet / PSANet.

    The constructor loads a config file, builds the network selected by its
    ``arch`` entry, wraps it in ``torch.nn.DataParallel`` on CUDA and restores
    the weights from ``model_path``. ``process_img_driveable`` is the
    high-level entry point returning colour renderings of the prediction.
    """

    def __init__(self, config_file=CONFIG_FILE):
        """Load the configuration and the model weights.

        Args:
            config_file: config path passed to ``config.load_cfg_from_cfg_file``.

        Raises:
            ValueError: if the configured ``arch`` is neither ``'psp'`` nor
                ``'psa'``.
            RuntimeError: if ``model_path`` is not an existing file.
        """
        # Load Parameters
        self.args_ = config.load_cfg_from_cfg_file(config_file)
        self.logger_ = get_logger()
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in self.args_.test_gpu)
        # ImageNet mean/std rescaled to the 0-255 pixel range (used by
        # net_process for normalisation and by scale_process for padding).
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        self.mean_ = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        self.std_ = [item * value_scale for item in std]

        # Load Model
        if self.args_.arch == 'psp':
            from model.pspnet import PSPNet
            self.model_ = PSPNet(layers=self.args_.layers,
                                 classes=self.args_.classes,
                                 zoom_factor=self.args_.zoom_factor,
                                 pretrained=False)
        elif self.args_.arch == 'psa':
            from model.psanet import PSANet
            self.model_ = PSANet(
                layers=self.args_.layers,
                classes=self.args_.classes,
                zoom_factor=self.args_.zoom_factor,
                compact=self.args_.compact,
                shrink_factor=self.args_.shrink_factor,
                mask_h=self.args_.mask_h,
                mask_w=self.args_.mask_w,
                normalization_factor=self.args_.normalization_factor,
                psa_softmax=self.args_.psa_softmax,
                pretrained=False)
        else:
            raise ValueError("=> unsupported arch '{}'".format(
                self.args_.arch))
        self.model_ = torch.nn.DataParallel(self.model_).cuda()
        cudnn.benchmark = True

        if os.path.isfile(self.args_.model_path):
            # Log through the stored logger; Logger.info returns None, so its
            # result must not be assigned back to self.logger_.
            self.logger_.info(
                "=> loading checkpoint '{}'".format(self.args_.model_path))
            checkpoint = torch.load(self.args_.model_path)
            self.model_.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger_.info(
                "=> loaded checkpoint '{}'".format(self.args_.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args_.model_path))

    def get_label_colors(self, driveable=True):
        """Return a ``{label: [r, g, b]}`` palette for labels 0..18.

        Args:
            driveable: when ``>= 0``, the single label to paint green
                (``[55, 195, 55]``) while every other label is black; when
                negative, the full 19-colour palette is returned. The default
                ``True`` compares and indexes as ``1``, so it highlights
                label 1.
        """
        if driveable >= 0:
            # Independent [0, 0, 0] lists so only one entry is replaced.
            colors = [[0, 0, 0] for _ in range(19)]
            colors[driveable] = [55, 195, 55]
        else:
            colors = [
                [128, 64, 128],
                [244, 35, 232],
                [70, 70, 70],
                [55, 195, 55],
                [190, 153, 153],
                [153, 153, 153],
                [250, 170, 30],
                [220, 220, 0],
                [107, 142, 35],
                [152, 251, 152],
                [0, 130, 180],
                [220, 20, 60],
                [255, 0, 0],
                [0, 0, 142],
                [0, 0, 70],
                [0, 60, 100],
                [0, 80, 100],
                [0, 0, 230],
                [119, 11, 32],
            ]

        return dict(zip(range(19), colors))

    def net_process(self, image, flip=True):
        """Run the model on one (H, W, C) image crop and return class scores.

        The crop is normalised with ``self.mean_`` / ``self.std_``. With
        ``flip`` the horizontally mirrored crop is also evaluated and the two
        softmax maps are averaged.

        Returns:
            numpy array of shape (H, W, num_classes) with softmax scores.
        """
        inp = torch.from_numpy(image.transpose((2, 0, 1))).float()
        if self.std_ is None:
            for t, m in zip(inp, self.mean_):
                t.sub_(m)
        else:
            for t, m, s in zip(inp, self.mean_, self.std_):
                t.sub_(m).div_(s)
        inp = inp.unsqueeze(0).cuda()
        if flip:
            inp = torch.cat([inp, inp.flip(3)], 0)
        with torch.no_grad():
            output = self.model_(inp)
        _, _, h_i, w_i = inp.shape
        _, _, h_o, w_o = output.shape
        # Upsample logits to the crop resolution if the network downsampled.
        if (h_o != h_i) or (w_o != w_i):
            output = F.interpolate(output, (h_i, w_i),
                                   mode='bilinear',
                                   align_corners=True)
        output = F.softmax(output, dim=1)
        if flip:
            # Un-mirror the second prediction before averaging.
            output = (output[0] + output[1].flip(2)) / 2
        else:
            output = output[0]
        output = output.data.cpu().numpy()
        output = output.transpose(1, 2, 0)
        return output

    def scale_process(self,
                      model,
                      image,
                      classes,
                      crop_h,
                      crop_w,
                      h,
                      w,
                      mean,
                      std=None,
                      stride_rate=2 / 3):
        """Sliding-window prediction over ``image``, resized to (h, w).

        The image is padded with ``mean`` up to at least ``crop_h x crop_w``,
        scored crop by crop with overlapping windows (stride
        ``ceil(crop * stride_rate)``), overlapping scores are averaged, the
        padding is removed and the score map is resized to ``(h, w)``.

        ``model`` and ``std`` are accepted for interface compatibility but are
        unused: ``net_process`` uses ``self.model_``, ``self.mean_`` and
        ``self.std_``.

        Returns:
            numpy array of shape (h, w, classes) with averaged scores.
        """
        ori_h, ori_w, _ = image.shape
        pad_h = max(crop_h - ori_h, 0)
        pad_w = max(crop_w - ori_w, 0)
        pad_h_half = int(pad_h / 2)
        pad_w_half = int(pad_w / 2)
        if pad_h > 0 or pad_w > 0:
            image = cv2.copyMakeBorder(image,
                                       pad_h_half,
                                       pad_h - pad_h_half,
                                       pad_w_half,
                                       pad_w - pad_w_half,
                                       cv2.BORDER_CONSTANT,
                                       value=mean)
        new_h, new_w, _ = image.shape
        stride_h = int(np.ceil(crop_h * stride_rate))
        stride_w = int(np.ceil(crop_w * stride_rate))
        grid_h = int(np.ceil(float(new_h - crop_h) / stride_h) + 1)
        grid_w = int(np.ceil(float(new_w - crop_w) / stride_w) + 1)
        prediction_crop = np.zeros((new_h, new_w, classes), dtype=float)
        count_crop = np.zeros((new_h, new_w), dtype=float)
        for index_h in range(0, grid_h):
            for index_w in range(0, grid_w):
                # Clamp the last window to the image edge so every window is
                # exactly crop_h x crop_w.
                s_h = index_h * stride_h
                e_h = min(s_h + crop_h, new_h)
                s_h = e_h - crop_h
                s_w = index_w * stride_w
                e_w = min(s_w + crop_w, new_w)
                s_w = e_w - crop_w
                image_crop = image[s_h:e_h, s_w:e_w].copy()
                count_crop[s_h:e_h, s_w:e_w] += 1
                prediction_crop[s_h:e_h,
                                s_w:e_w, :] += self.net_process(image_crop)
        prediction_crop /= np.expand_dims(count_crop, 2)
        prediction_crop = prediction_crop[pad_h_half:pad_h_half + ori_h,
                                          pad_w_half:pad_w_half + ori_w]
        prediction = cv2.resize(prediction_crop, (w, h),
                                interpolation=cv2.INTER_LINEAR)
        return prediction

    def test(self, model, image, classes, base_size, crop_h, crop_w, scales):
        """Predict a label map for a BGR image, summing scores over ``scales``.

        For each scale the longer image side is resized to
        ``round(scale * base_size)`` (aspect ratio kept) and scored with
        ``scale_process``; the per-scale score maps are accumulated and the
        arg-max class is returned.

        Returns:
            uint8 array of shape (H, W) with one label per pixel.
        """
        # convert cv2 read image from BGR order to RGB order
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape
        prediction = np.zeros((h, w, classes), dtype=float)
        for scale in scales:
            long_size = round(scale * base_size)
            new_h = long_size
            new_w = long_size
            if h > w:
                new_w = round(long_size / float(h) * w)
            else:
                new_h = round(long_size / float(w) * h)
            image_scale = cv2.resize(image, (new_w, new_h),
                                     interpolation=cv2.INTER_LINEAR)
            prediction += self.scale_process(model, image_scale, classes,
                                             crop_h, crop_w, h, w, self.mean_,
                                             self.std_)
        # The accumulated multi-scale scores are used directly; they are no
        # longer overwritten by an extra pass over the last scale only.
        prediction = np.argmax(prediction, axis=2)
        gray = np.uint8(prediction)
        return gray

    def colorize(self, labels, label_colors):
        """Map a 2-D label image to an RGB uint8 image of shape (H, W, 3).

        Labels are clipped to ``[0, len(label_colors) - 1]`` and looked up in
        ``label_colors`` (a ``{label: [r, g, b]}`` mapping with keys
        ``0 .. len(label_colors) - 1``), so any palette size is supported.
        """
        num_colors = len(label_colors)
        labels = np.clip(labels, 0, num_colors - 1).astype(np.intp)
        palette = np.array([label_colors[i] for i in range(num_colors)],
                           dtype=np.uint8)
        return palette[labels]

    def process_img_driveable(self, img, size, drivable_idx=-1):
        """Segment ``img`` and return (full-palette image, drivable-mask image).

        Args:
            img: image passed on to ``test``, which converts it from BGR to
                RGB, so it presumably comes from ``cv2.imread`` -- confirm
                against callers.
            size: target ``(height, width)``; ``img`` is resized to it first.
            drivable_idx: label painted green in the second image (see
                ``get_label_colors``); with the default ``-1`` both returned
                images use the full palette.
        """
        # cv2.resize takes (width, height).
        img_resized = cv2.resize(img, (int(size[1]), int(size[0])))
        labeled_img = self.test(self.model_.eval(), img_resized,
                                self.args_.classes, self.args_.base_size,
                                self.args_.test_h, self.args_.test_w,
                                self.args_.scales)
        segmented_img = self.colorize(labeled_img,
                                      self.get_label_colors(driveable=-1))
        drivable_img = self.colorize(
            labeled_img, self.get_label_colors(driveable=drivable_idx))

        return segmented_img, drivable_img
def main_worker(gpu, ngpus_per_node, argss):
    """Per-process training entry point for the flow-augmented segmentation model.

    Sets up (optionally distributed) training, builds the student network
    selected by ``args.arch``, optionally warm-starts it from
    ``args.tune_weight``, wraps it in ``FlowModel``, optionally resumes from
    ``args.resume``, then trains for ``args.start_epoch .. args.epochs - 1``.
    Every ``args.save_freq`` epochs a checkpoint is written on the main
    process and the one from ``2 * args.save_freq`` epochs earlier is deleted,
    so the two most recent checkpoints are kept.

    Args:
        gpu: local GPU index of this process.
        ngpus_per_node: number of GPUs per node (used for rank and per-GPU
            batch-size/worker splitting).
        argss: parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
    """
    # logger/writer are only assigned on the main process but are declared
    # global here so every later use refers to the module-level names.
    global args, logger, writer
    args = argss
    if args.sync_bn:
        if args.multiprocessing_distributed:
            BatchNorm = apex.parallel.SyncBatchNorm
        else:
            from lib.sync_bn.modules import BatchNorm2d
            BatchNorm = BatchNorm2d
    else:
        BatchNorm = nn.BatchNorm2d
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size,
                                rank=args.rank)
    from util.crit import CriterionDSN
    criterion = CriterionDSN(ignore_index=args.ignore_label)
    # modules_ori: backbone (base learning rate); modules_new: heads.
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psp18':
        from model.pspnet_18 import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion,
                       BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'antipsp18':
        from model.antipspnet18 import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion,
                       BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, psa_type=args.psa_type,
                       compact=args.compact, shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax,
                       criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    elif args.arch == 'mobile':
        from model.mobile import DenseASPP
        model = DenseASPP(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion,
                          BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.features]
        modules_new = [model.ppm, model.cls, model.aux]
    else:
        # Fail early instead of a NameError on `model` / `modules_ori`.
        raise ValueError("=> unsupported arch '{}'".format(args.arch))
    params_list = []
    args.index_split = 5
    if 'mobile' in args.arch:
        args.index_split = 1

    if args.tune_weight:
        student_ckpt = torch.load(args.tune_weight,
                                  map_location=torch.device('cpu'))
        new_params = model.state_dict().copy()
        if 'state_dict' in student_ckpt:
            student_ckpt = student_ckpt['state_dict']
        # Strip DataParallel ('module.') and distillation ('student.') prefixes.
        for i in student_ckpt:
            if 'module' in i or 'student' in i:
                new_params[i.replace('module.', '').replace('student.', '')] = student_ckpt[i]
            else:
                new_params[i] = student_ckpt[i]
        # load_state_dict returns (missing_keys, unexpected_keys), in that order.
        missing_keys, unexpected_keys = model.load_state_dict(new_params, strict=False)
        print('missing keys:', missing_keys)
        print('unexpected keys:', unexpected_keys)
        del new_params
        del student_ckpt
        # Fine-tuning: backbone and heads share the base learning rate.
        for module in modules_ori:
            params_list.append(dict(params=module.parameters(), lr=args.base_lr))
        for module in modules_new:
            params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    else:
        # From scratch: freshly initialised heads learn 10x faster.
        for module in modules_ori:
            params_list.append(dict(params=module.parameters(), lr=args.base_lr))
        for module in modules_new:
            params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))

    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    criterion_list = nn.ModuleDict({'ce': criterion})

    from model.flow_model import FlowModel

    model = FlowModel(model, criterion_list, args)
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        # Global batch sizes / workers are split across the GPUs of a node.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)
        if args.use_apex:
            model, optimizer = apex.amp.initialize(model.cuda(), optimizer, opt_level=args.opt_level,
                                                   keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                   loss_scale=args.loss_scale)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.VideoData(split='train', data_root=args.data_root, data_list=args.train_list,
                                   transform=train_transform, frame_gap=args.frame_gap, random_frame=args.random_frame)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None),
                                               num_workers=args.workers, pin_memory=True, sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.VideoData(split='val', data_root=args.data_root, data_list=args.val_list,
                                     transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       filename)
            # Keep only the two most recent checkpoints. The older file may
            # not exist (e.g. after resuming elsewhere); that must not abort
            # training.
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    pass
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
Esempio n. 12
0
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion,
                       args=args)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    else:
        logger = None
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume != 'none':
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            # model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # print(checkpoint['optimizer'].keys())
            if args.if_remove_cls:
                if main_process():
                    logger.info(
                        '=====!!!!!!!===== Remove cls layer in resuming...')
                checkpoint['state_dict'] = {
                    x: checkpoint['state_dict'][x]
                    for x in checkpoint['state_dict'].keys()
                    if ('module.cls' not in x and 'module.aux' not in x)
                }
                # checkpoint['optimizer'] = {x: checkpoint['optimizer'][x] for x in checkpoint['optimizer'].keys() if ('module.cls' not in x and 'module.aux' not in x)}
                # if main_process():
                #     print('----', checkpoint['state_dict'].keys())
                #     print('----', checkpoint['optimizer'].keys())
                #     print('----1', checkpoint['optimizer']['state'].keys())

            model.load_state_dict(checkpoint['state_dict'], strict=False)
            if not args.if_remove_cls:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    transform_list_train = []
    if args.resize:
        transform_list_train.append(
            transform.Resize((args.resize_h, args.resize_w)))
    transform_list_train += [
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(transform_list_train)
    train_data = dataset.SemData(split='val',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform,
                                 logger=logger,
                                 is_master=main_process(),
                                 args=args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    if args.evaluate:
        transform_list_val = []
        if args.resize:
            transform_list_val.append(
                transform.Resize((args.resize_h, args.resize_w)))
        transform_list_val += [
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
        val_transform = transform.Compose(transform_list_val)
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform,
                                   is_master=main_process(),
                                   args=args)
        args.read_image = val_data.read_image
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # if args.evaluate and args.val_every_iter == -1:
        #     # logger.info('Validating.....')
        #     loss_val, mIoU_val, mAcc_val, allAcc_val, return_dict = validate(val_loader, model, criterion, args)
        #     if main_process():
        #         writer.add_scalar('VAL/loss_val', loss_val, epoch_log)
        #         writer.add_scalar('VAL/mIoU_val', mIoU_val, epoch_log)
        #         writer.add_scalar('VAL/mAcc_val', mAcc_val, epoch_log)
        #         writer.add_scalar('VAL/allAcc_val', allAcc_val, epoch_log)

        #         for sample_idx in range(len(return_dict['image_name_list'])):
        #             writer.add_text('VAL-image_name/%d'%sample_idx, return_dict['image_name_list'][sample_idx], epoch)
        #             writer.add_image('VAL-image/%d'%sample_idx, return_dict['im_list'][sample_idx], epoch, dataformats='HWC')
        #             writer.add_image('VAL-color_label/%d'%sample_idx, return_dict['color_GT_list'][sample_idx], epoch, dataformats='HWC')
        #             writer.add_image('VAL-color_pred/%d'%sample_idx, return_dict['color_pred_list'][sample_idx], epoch, dataformats='HWC')

        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch, epoch_log, val_loader,
            criterion)
        if main_process():
            writer.add_scalar('TRAIN/loss_train', loss_train, epoch_log)
            writer.add_scalar('TRAIN/mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('TRAIN/mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('TRAIN/allAcc_train', allAcc_train, epoch_log)