Example #1
0
def get_parser():
    """Parse CLI arguments, load/merge the YAML config, and back it up.

    Returns the config object after validating it with check() and copying
    the YAML file into cfg.save_path (which is recreated from scratch).
    """
    cli = argparse.ArgumentParser(
        description='PyTorch Semantic Segmentation')
    cli.add_argument('--config', type=str,
                     default='config/ade20k/ade20k_pspnet50.yaml',
                     help='config file')
    cli.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                     help='see config/ade20k/ade20k_pspnet50.yaml for all options')
    parsed = cli.parse_args()
    assert parsed.config is not None
    cfg = config.load_cfg_from_cfg_file(parsed.config)
    if parsed.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, parsed.opts)
    check(cfg)

    # Recreate save_path from scratch and keep a copy of the config there.
    if os.path.exists(cfg.save_path):
        shutil.rmtree(cfg.save_path)

    os.makedirs(cfg.save_path)
    shutil.copyfile(parsed.config, os.path.join(cfg.save_path, "config.yaml"))
    return cfg
def get_parser():
    """Parse the --config option and return the loaded configuration."""
    cli = argparse.ArgumentParser(description='Few-Shot Semantic Segmentation')
    cli.add_argument('--config', type=str,
                     default='config/MSD/fold0_resnet50.yaml',
                     help='config file')

    parsed = cli.parse_args()
    return config.load_cfg_from_cfg_file(parsed.config)
Example #3
0
def get_parser():
    """Build the CLI, then return the config merged with trailing opts."""
    arg_parser = argparse.ArgumentParser(description='PyTorch Semantic Segmentation')
    arg_parser.add_argument('--config', type=str, help='config file',
                            default='config/ade20k/ade20k_pspnet50.yaml')
    arg_parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                            help='see config/ade20k/ade20k_pspnet50.yaml for all options')
    ns = arg_parser.parse_args()
    assert ns.config is not None
    merged = config.load_cfg_from_cfg_file(ns.config)
    if ns.opts is not None:
        merged = config.merge_cfg_from_list(merged, ns.opts)
    return merged
Example #4
0
def get_parser():
    """Load and return the fixed DeepLabV3/ResNet-101 VOC2012 configuration.

    The argparse-based CLI that the other variants use is intentionally
    bypassed here; the config path is hard-coded to a Google Drive mount.
    (The dead commented-out argparse code was removed.)
    """
    # NOTE(review): path is environment-specific (Colab drive mount);
    # callers rely on the zero-argument signature, so it stays hard-coded.
    cfg = config.load_cfg_from_cfg_file(
        '/content/drive/My Drive/Segmentation/config/voc2012/voc2012_deeplabv3_resnet101.yaml')
    return cfg
Example #5
0
def get_parser():
    """Parse --config/--image plus trailing opts and return the cfg."""
    cli = argparse.ArgumentParser(description='PyTorch Semantic Segmentation')
    cli.add_argument('--config', type=str, help='config file',
                     default='config/cityscapes/cityscapes_pspnet18.yaml')
    cli.add_argument('--image', type=str, help='input image',
                     default='figure/demo/cityscapes.jpg')
    cli.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                     help='see config/cityscapes/cityscapes_pspnet18.yaml for all options')
    ns = cli.parse_args()
    assert ns.config is not None
    cfg = config.load_cfg_from_cfg_file(ns.config)
    # Attach the demo image path so downstream code can read cfg.image.
    cfg.image = ns.image
    if ns.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, ns.opts)
    return cfg
    def __init__(self, config_file=CONFIG_FILE):
        """Load the config, build the PSP/PSA model, and restore its weights.

        Args:
            config_file: path to the YAML configuration (default CONFIG_FILE).

        Raises:
            RuntimeError: if no checkpoint exists at ``self.args_.model_path``.
        """
        # Load Parameters
        self.args_ = config.load_cfg_from_cfg_file(config_file)
        self.logger_ = get_logger()
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in self.args_.test_gpu)
        # ImageNet mean/std rescaled to the 0-255 pixel range.
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        self.mean_ = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        self.std_ = [item * value_scale for item in std]
        # self.colors_ = np.loadtxt(self.args_.colors_path).astype('uint8')

        # Load Model (architecture selected by the config's 'arch' field).
        if self.args_.arch == 'psp':
            from model.pspnet import PSPNet
            self.model_ = PSPNet(layers=self.args_.layers,
                                 classes=self.args_.classes,
                                 zoom_factor=self.args_.zoom_factor,
                                 pretrained=False)
        elif self.args_.arch == 'psa':
            from model.psanet import PSANet
            self.model_ = PSANet(
                layers=self.args_.layers,
                classes=self.args_.classes,
                zoom_factor=self.args_.zoom_factor,
                compact=self.args_.compact,
                shrink_factor=self.args_.shrink_factor,
                mask_h=self.args_.mask_h,
                mask_w=self.args_.mask_w,
                normalization_factor=self.args_.normalization_factor,
                psa_softmax=self.args_.psa_softmax,
                pretrained=False)
        self.model_ = torch.nn.DataParallel(self.model_).cuda()
        cudnn.benchmark = True

        if os.path.isfile(self.args_.model_path):
            # BUG FIX: the original rebound self.logger_ to the return value
            # of Logger.info() (None), twice, destroying the logger created
            # above; log through the existing logger instead.
            self.logger_.info(
                "=> loading checkpoint '{}'".format(self.args_.model_path))
            checkpoint = torch.load(self.args_.model_path)
            self.model_.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger_.info(
                "=> loaded checkpoint '{}'".format(self.args_.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args_.model_path))
Example #7
0
def get_parser():
    """Return the point-cloud config, merged with trailing CLI overrides."""
    cli = argparse.ArgumentParser(
        description='PyTorch Point Cloud Classification / Semantic Segmentation'
    )
    cli.add_argument('--config', type=str,
                     default='config/modelnet40/modelnet40_pointweb.yaml',
                     help='config file')
    cli.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                     help='see config/s3dis/s3dis_pointweb.yaml for all options')
    ns = cli.parse_args()
    assert ns.config is not None
    result = config.load_cfg_from_cfg_file(ns.config)
    if ns.opts is not None:
        result = config.merge_cfg_from_list(result, ns.opts)
    return result
Example #8
0
def get_parser():
    """Create the ArgumentParser, parse args, and return the loaded CfgNode."""
    # Create the ArgumentParser object.
    cli = argparse.ArgumentParser(
        description='PyTorch Semantic Segmentation')
    # A '--' prefix marks an optional argument; bare names are positional.
    cli.add_argument(
        '--config',
        type=str,
        default='/data/ss/URISC/CODE/FullVersion/config/URISC/urisc_unet.yaml',
        help='config file')
    cli.add_argument(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='see config/URISC/urisc_unet.yaml for all options')
    # Parse the registered arguments.
    ns = cli.parse_args()
    assert ns.config is not None
    # Load the YAML file into a CfgNode.
    cfg = config.load_cfg_from_cfg_file(ns.config)
    if ns.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, ns.opts)
    return cfg
Example #9
0
def main(
        config_name,
        weights_url='https://github.com/deepparrot/semseg/releases/download/0.1/pspnet50-ade20k.pth',
        weights_name='pspnet50-ade20k.pth'):
    """Evaluate a pretrained PSPNet/PSANet on the ADE20K validation split.

    Loads the YAML config, downloads the checkpoint from ``weights_url``,
    runs test() over the selected slice of the validation list, and (for
    non-'test' splits) scores the grayscale predictions with cal_acc().

    Args:
        config_name: path to the YAML config file.
        weights_url: URL of the pretrained checkpoint to fetch.
        weights_name: local filename for the downloaded checkpoint.
    """
    args = config.load_cfg_from_cfg_file(config_name)
    check(args)

    # Restrict CUDA to the GPUs listed in the config.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)

    # ImageNet mean/std rescaled from [0, 1] to the 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Hard-coded ADE20K locations override whatever the config specified.
    args.data_root = './.data/vision/ade20k'
    args.val_list = './.data/vision/ade20k/validation.txt'
    args.test_list = './.data/vision/ade20k/validation.txt'

    print(args.data_root)

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    # Evaluate only the [index_start, index_end) window of the data list;
    # index_step == 0 means "run to the end of the list".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # NOTE(review): an empty class-name list is passed to cal_acc below —
    # presumably names are optional there; confirm against cal_acc.
    names = []

    if not args.has_prediction:
        # Build the model named by the config ('psp' or 'psa').
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

        # Fetch the pretrained weights to a local file.
        local_checkpoint, _ = urllib.request.urlretrieve(
            weights_url, weights_name)

        if os.path.isfile(local_checkpoint):
            checkpoint = torch.load(local_checkpoint)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(local_checkpoint))
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, args.test_h, args.test_w, args.scales,
             gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
Example #10
0
File: sotabench.py  Project: Randl/SAN
# Model 1: SAN10-pairwise — download weights, build model, run the benchmark.
config_path = 'config/imagenet/imagenet_san10_pairwise.yaml'
file_id = '1lv5TYfJFYvNWt_Ik0E-nAuI5h4PqSwuk'
destination = './tmp/'
filename = 'imagenet_san10_pairwise.pth'
download_file_from_google_drive(file_id, destination, filename=filename)
checkpoint = torch.load(os.path.join(destination, filename))
# Strip the DataParallel 'module.' prefix so a bare model can load the weights.
sd = {}
for key in checkpoint['state_dict']:
    sd[key.replace('module.', '')] = checkpoint['state_dict'][key]
# Transforms needed to convert ImageNet data to the expected model input.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
input_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])

# Build the SAN model from the YAML config and load the cleaned state dict.
args = config.load_cfg_from_cfg_file(config_path)
model = san(args.sa_type, args.layers, args.kernels, args.classes)
model.load_state_dict(sd)

# Run the benchmark
ImageNet.benchmark(
    model=model,
    paper_model_name='SAN10-pairwise',
    paper_arxiv_id='2004.13621',
    input_transform=input_transform,
    batch_size=256,
    num_gpu=1,
    paper_results={'Top 1 Accuracy': 0.749, 'Top 5 Accuracy': 0.921},
    model_description="Official weights from the authors of the paper.",
)
# Free cached GPU memory before any subsequent benchmark run.
torch.cuda.empty_cache()
def get_parser():
    """Load and return the configuration from the module-level CONFIG_FILE."""
    return config.load_cfg_from_cfg_file(CONFIG_FILE)