Exemplo n.º 1
0
def main(config_file, exp_suffix):
    """Evaluate the configured model(s) on the target-domain test set.

    Loads ``config_file`` into the global ``cfg``, fills in the experiment
    name and the first snapshot directory when they are empty, builds one
    network per entry of ``cfg.TEST.MODEL`` and passes them, together with a
    target-domain data loader, to ``evaluate_domain_adaptation``.

    Args:
        config_file: Value handed to ``cfg_from_file``; must not be ``None``.
        exp_suffix: Text appended to ``cfg.EXP_NAME`` after an underscore;
            skipped when falsy.

    Raises:
        AssertionError: If ``config_file`` is ``None``, or if
            ``cfg.TEST.MODE`` is ``'best'`` and the number of test models is
            not exactly one.
        NotImplementedError: If a test model name is not ``'DeepLabv2'``.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    # Derive an experiment name from the config when none was supplied.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME = cfg.EXP_NAME + f'_{exp_suffix}'

    # Default the first snapshot directory to <snapshot root>/<exp name> and
    # make sure it exists.
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        snapshot_dir = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        cfg.TEST.SNAPSHOT_DIR[0] = snapshot_dir
        os.makedirs(snapshot_dir, exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    models = []
    for i, model_name in enumerate(cfg.TEST.MODEL):
        if model_name != 'DeepLabv2':
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                     multi_level=cfg.TEST.MULTI_LEVEL[i]))

    # ADVENT_DRY_RUN=1 stops here: config and models are built, but no
    # dataset is opened and no evaluation runs.
    dry_run = os.environ.get('ADVENT_DRY_RUN', '0') == '1'
    if dry_run:
        return

    # Keyword arguments common to both dataset classes.
    shared_kwargs = dict(
        root=cfg.DATA_DIRECTORY_TARGET,
        set=cfg.TEST.SET_TARGET,
        crop_size=cfg.TEST.INPUT_SIZE_TARGET,
        mean=cfg.TEST.IMG_MEAN,
        labels_size=cfg.TEST.OUTPUT_SIZE_TARGET,
    )
    if cfg.TARGET == 'Mapillary':
        test_dataset = MapillaryDataSet(scale_label=False, **shared_kwargs)
    else:
        # Every other target name is loaded through the Cityscapes dataset.
        test_dataset = CityscapesDataSet(list_path=cfg.DATA_LIST_TARGET,
                                         info_path=cfg.TEST.INFO_TARGET,
                                         num_classes=cfg.NUM_CLASSES,
                                         **shared_kwargs)

    test_loader = data.DataLoader(
        test_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_TARGET,
        num_workers=cfg.NUM_WORKERS,
        shuffle=False,
        pin_memory=True,
    )

    evaluate_domain_adaptation(models, test_loader, cfg)
Exemplo n.º 2
0
def main():
    """Parse CLI arguments and evaluate the configured model(s) on Cityscapes.

    Reads the config path and experiment suffix from ``get_arguments()``,
    loads the config into the global ``cfg``, fills in the experiment name
    and the first snapshot directory when they are empty, builds one network
    per entry of ``cfg.TEST.MODEL`` and passes them, with a Cityscapes data
    loader, to ``evaluate_domain_adaptation``.

    Raises:
        AssertionError: If no config file was given, or if ``cfg.TEST.MODE``
            is ``'best'`` and the number of test models is not exactly one.
        NotImplementedError: If a test model name is not ``'DeepLabv2'``.
    """
    args = get_arguments()
    cfg_path = args.cfg
    suffix = args.exp_suffix
    assert cfg_path is not None, 'Missing cfg file'
    cfg_from_file(cfg_path)

    # Derive an experiment name from config + CLI options when none is set.
    if cfg.EXP_NAME == '':
        if not args.MBT:
            # The dotted LB value is made name-safe and written back to
            # ``args``, so the dataset constructed below also receives the
            # string form.
            args.LB = str(args.LB).replace('.', '_')
            cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}_THRESH_{args.thres}_ROUND_{args.round}'
        else:
            # Model trained on pseudo labels from MBT.
            cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_MBT_THRESH_{args.thres}_ROUND_{args.round}'
    if suffix:
        cfg.EXP_NAME = cfg.EXP_NAME + f'_{suffix}'

    # Default the first snapshot directory to <snapshot root>/<exp name> and
    # make sure it exists.
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        snapshot_dir = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        cfg.TEST.SNAPSHOT_DIR[0] = snapshot_dir
        os.makedirs(snapshot_dir, exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    models = []
    for i, model_name in enumerate(cfg.TEST.MODEL):
        if model_name != 'DeepLabv2':
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                     multi_level=cfg.TEST.MULTI_LEVEL[i]))

    # ADVENT_DRY_RUN=1 stops here: config and models are built, but no
    # dataset is opened and no evaluation runs.
    dry_run = os.environ.get('ADVENT_DRY_RUN', '0') == '1'
    if dry_run:
        return

    # NOTE(review): list_path is a hard-coded relative path, so it resolves
    # against the current working directory rather than the config.
    dataset_kwargs = dict(
        args=args,
        root=cfg.DATA_DIRECTORY_TARGET,
        list_path='../ADVENT/advent/dataset/cityscapes_list/{}.txt',
        set=cfg.TEST.SET_TARGET,
        info_path=cfg.TEST.INFO_TARGET,
        crop_size=cfg.TEST.INPUT_SIZE_TARGET,
        mean=cfg.TEST.IMG_MEAN,
        labels_size=cfg.TEST.OUTPUT_SIZE_TARGET,
    )
    test_dataset = CityscapesDataSet(**dataset_kwargs)

    test_loader = data.DataLoader(
        test_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_TARGET,
        num_workers=cfg.NUM_WORKERS,
        shuffle=False,
        pin_memory=True,
    )

    evaluate_domain_adaptation(models, test_loader, cfg)
Exemplo n.º 3
0
def main(config_file, exp_suffix):
    """Evaluate the configured model(s) on the selected target dataset.

    Loads ``config_file`` into the global ``cfg``, overrides the worker count
    and the source/target dataset names from ``args``, fills in dataset paths
    and experiment directories for the chosen target, builds one network per
    entry of ``cfg.TEST.MODEL`` and passes them, with a target-domain data
    loader, to ``evaluate_domain_adaptation``.

    NOTE(review): ``args`` is not a parameter. This function reads a
    module-level ``args`` object (``num_workers``, ``source``, ``target``)
    and fails with ``NameError`` if none is defined. Presumably the script
    entry point sets it; confirm before calling ``main`` from other code.

    Args:
        config_file: Value handed to ``cfg_from_file``; must not be ``None``.
        exp_suffix: Text appended to ``cfg.EXP_NAME`` after an underscore;
            skipped when falsy.

    Raises:
        AssertionError: If ``config_file`` is ``None``, or if
            ``cfg.TEST.MODE`` is ``'best'`` and the number of test models is
            not exactly one.
        NotImplementedError: If the source is not ``'GTA'``, the target is
            neither ``'Cityscapes'`` nor ``'BDD'``, or a test model is not
            supported.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    cfg.NUM_WORKERS = args.num_workers

    # Dataset selection comes from the command line, not the config file.
    cfg.SOURCE = args.source
    cfg.TARGET = args.target

    # Source config: only GTA is wired up; SYNTHIA and anything else is
    # rejected with the same error.
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root / 'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    # Target config: paths, experiment directories and image sizes.
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')
    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TEST.INPUT_SIZE_TARGET = (960, 540)
        cfg.TEST.OUTPUT_SIZE_TARGET = (1280, 720)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/compound_list/info.json')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # Derive an experiment name from the config when none was supplied.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'

    # Default the first snapshot directory to <snapshot root>/<exp name> and
    # make sure it exists.
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TEST.SNAPSHOT_DIR[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    models = []
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    for i, model_name in enumerate(cfg.TEST.MODEL):
        if model_name == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        elif 'DeepLabv2_VGG' in (model_name, cfg.TRAIN.MODEL):
            # Select VGG from the test-model entry itself. The previous code
            # looked only at cfg.TRAIN.MODEL, so a test config naming
            # 'DeepLabv2_VGG' was rejected unless the training model matched.
            # The TRAIN.MODEL check stays as a fallback so existing configs
            # keep working.
            model = get_deeplab_v2_vgg(cfg=cfg,
                                       num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {model_name}")
        models.append(model)

    # ADVENT_DRY_RUN=1 stops here: config and models are built, but no
    # dataset is opened and no evaluation runs.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Both dataset classes are constructed with the same keyword arguments,
    # so only the class differs per target.
    if cfg.TARGET == 'Cityscapes':
        dataset_cls = CityscapesDataSet
    elif cfg.TARGET == 'BDD':
        dataset_cls = BDDdataset
    else:
        # Defensive: the target was already validated above.
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} datasets")

    test_dataset = dataset_cls(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=cfg.DATA_LIST_TARGET,
                               set=cfg.TEST.SET_TARGET,
                               info_path=cfg.TEST.INFO_TARGET,
                               crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                               mean=cfg.TEST.IMG_MEAN,
                               labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)

    evaluate_domain_adaptation(models, test_loader, cfg)