コード例 #1
0
def main(config_file, exp_suffix):
    """Evaluate the configured segmentation model(s) with a ``BDDataSet`` loader.

    Loads ``config_file`` into the global ``cfg``, fills in the experiment
    name and the first snapshot directory when they are empty, builds one
    network per entry of ``cfg.TEST.MODEL`` and hands them, together with a
    non-shuffled test loader, to ``evaluate_domain_adaptation``.

    Args:
        config_file: Path of the configuration file; must not be ``None``.
        exp_suffix: Optional suffix appended to the experiment name.

    Raises:
        AssertionError: If ``config_file`` is ``None``, or if more than one
            model is configured while ``cfg.TEST.MODE`` is ``'best'``.
        NotImplementedError: If a ``cfg.TEST.MODEL`` entry is not supported.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    # Derive the experiment name from the config when none was given.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME = cfg.EXP_NAME + f'_{exp_suffix}'

    # Derive (and create) the first snapshot directory when it is empty.
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        auto_snapshot_dir = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        cfg.TEST.SNAPSHOT_DIR[0] = auto_snapshot_dir
        os.makedirs(auto_snapshot_dir, exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured model name.
    model_names = cfg.TEST.MODEL
    if cfg.TEST.MODE == 'best':
        assert len(model_names) == 1, 'Not yet supported'

    models = []
    for i, model_name in enumerate(model_names):
        if model_name == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        elif model_name == 'DeepLabv2_VGG':
            model = get_deeplab_v2_vgg(cfg=cfg,
                                       num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(model)

    # Dry-run mode stops once the config is resolved and the models are built.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Test dataset and its loader (fixed order, pinned memory).
    dataset_kwargs = {
        'root': cfg.DATA_DIRECTORY_TARGET,
        'list_path': cfg.DATA_LIST_TARGET,
        'set': cfg.TEST.SET_TARGET,
        'info_path': cfg.TEST.INFO_TARGET,
        'crop_size': cfg.TEST.INPUT_SIZE_TARGET,
        'mean': cfg.TEST.IMG_MEAN,
        'labels_size': cfg.TEST.OUTPUT_SIZE_TARGET,
    }
    test_dataset = BDDataSet(**dataset_kwargs)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)

    evaluate_domain_adaptation(models, test_loader, cfg)
コード例 #2
0
ファイル: test.py プロジェクト: gritYCDA/boundaryOCDA
def main(config_file, exp_suffix):
    """Evaluate the configured model(s) on the selected target dataset.

    Loads ``config_file`` into the global ``cfg``, then overrides the worker
    count and the source/target dataset settings from ``args``. ``args`` is
    not defined inside this function, so it must already exist in the
    enclosing (module) scope. Afterwards the experiment name and the first
    snapshot directory are auto-filled when empty, one network is built per
    entry of ``cfg.TEST.MODEL`` and ``evaluate_domain_adaptation`` is run on a
    non-shuffled loader for the chosen target dataset.

    Args:
        config_file: Path of the configuration file; must not be ``None``.
        exp_suffix: Optional suffix appended to the experiment name.

    Raises:
        AssertionError: If ``config_file`` is ``None``, or if more than one
            model is configured while ``cfg.TEST.MODE`` is ``'best'``.
        NotImplementedError: If the source dataset, target dataset or a
            ``cfg.TEST.MODEL`` entry is not supported.
    """
    # LOAD ARGS
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    cfg.NUM_WORKERS = args.num_workers

    ### dataset settings
    cfg.SOURCE = args.source
    cfg.TARGET = args.target

    ## source config: only GTA is wired up; every other value (including
    ## 'SYNTHIA') is rejected with the same error.
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root / 'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    ## target config
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')

    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TEST.INPUT_SIZE_TARGET = (960, 540)
        cfg.TEST.OUTPUT_SIZE_TARGET = (1280, 720)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/compound_list/info.json')

    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'
    # auto-generate snapshot path if not specified
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TEST.SNAPSHOT_DIR[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # load models
    models = []
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    for i, model_name in enumerate(cfg.TEST.MODEL):
        if model_name == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        # Dispatch on the per-model test entry (not cfg.TRAIN.MODEL) so that
        # every model in the list is selected by its own name and the error
        # below reports the value that was actually checked.
        elif model_name == 'DeepLabv2_VGG':
            model = get_deeplab_v2_vgg(cfg=cfg,
                                       num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(model)

    # Dry-run mode stops once the config is resolved and the models are built.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # dataloaders: both target datasets take the same arguments, so only the
    # dataset class depends on the target.
    if cfg.TARGET == 'Cityscapes':
        dataset_cls = CityscapesDataSet
    elif cfg.TARGET == 'BDD':
        dataset_cls = BDDdataset
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} datasets")

    test_dataset = dataset_cls(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=cfg.DATA_LIST_TARGET,
                               set=cfg.TEST.SET_TARGET,
                               info_path=cfg.TEST.INFO_TARGET,
                               crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                               mean=cfg.TEST.IMG_MEAN,
                               labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)
    # eval
    evaluate_domain_adaptation(models, test_loader, cfg)
コード例 #3
0
def main():
    """Train a segmentation network with domain adaptation (GTA5 -> BDD).

    Parses the command-line arguments, loads the config file into the global
    ``cfg``, auto-fills the experiment name, snapshot directory and
    tensorboard directory, seeds the random number generators unless
    ``args.random_train`` is set, restores the segmentation network, builds
    the source and target loaders, dumps the resolved config to
    ``train_cfg.yml`` in the snapshot directory and calls
    ``train_domain_adaptation``.

    Raises:
        AssertionError: If no config file was passed.
        NotImplementedError: If ``cfg.TRAIN.MODEL`` is not supported.
    """
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'

    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'
    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)
    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS, 'tensorboard', cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        # An empty log dir is written back when tensorboard is not requested.
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''
    print('Using config:')
    pprint.pprint(cfg)

    shuffle = cfg.TRAIN.SHUFFLE

    # INIT: seed every RNG (and each loader worker) unless random training
    # was requested.
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    # Dry-run mode stops once the config is resolved and the RNGs are seeded.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Default target image list. Only the self-training VGG path below swaps
    # in the ordered list; defining it here keeps the 'DeepLabv2' path from
    # hitting an unbound ``trg_list`` when the target dataset is built.
    trg_list = cfg.DATA_LIST_TARGET

    # LOAD SEGMENTATION NET
    if cfg.TRAIN.MODEL == 'DeepLabv2':
        model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES, multi_level=cfg.TRAIN.MULTI_LEVEL)
        saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
        if 'DeepLab_resnet_pretrained_imagenet' in cfg.TRAIN.RESTORE_FROM:
            # For this checkpoint the first dotted component of every key is
            # dropped and 'layer5' entries are skipped, so those parameters
            # keep the values of the freshly built model.
            new_params = model.state_dict().copy()
            for key in saved_state_dict:
                key_parts = key.split('.')
                if key_parts[1] != 'layer5':
                    new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
            model.load_state_dict(new_params)
        else:
            model.load_state_dict(saved_state_dict)
    elif cfg.TRAIN.MODEL == 'DeepLabv2_VGG':
        model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES, pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)

        if cfg.TRAIN.SELF_TRAINING:
            # Self-training resumes from a fixed-iteration checkpoint and
            # reads the target images from the ordered list.
            path = osp.join(cfg.TRAIN.RESTORE_FROM_SELF, 'model_46000.pth')
            saved_state_dict = torch.load(path)
            model.load_state_dict(saved_state_dict, strict=False)
            trg_list = cfg.DATA_LIST_TARGET_ORDER
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")

    print('Model loaded')

    # DATALOADERS
    source_dataset = GTA5DataSet(root=cfg.DATA_DIRECTORY_SOURCE,
                                 list_path=cfg.DATA_LIST_SOURCE,
                                 set=cfg.TRAIN.SET_SOURCE,
                                 # max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE,
                                 max_iters=None,
                                 crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
                                 mean=cfg.TRAIN.IMG_MEAN)

    source_loader = data.DataLoader(source_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    target_dataset = BDDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=trg_list,
                               set=cfg.TRAIN.SET_TARGET,
                               info_path=cfg.TRAIN.INFO_TARGET,
                               # max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
                               max_iters=None,
                               crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                               mean=cfg.TRAIN.IMG_MEAN)

    # Target shuffling is config-driven (cfg.TRAIN.SHUFFLE), unlike the
    # always-shuffled source loader.
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=shuffle,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    # Persist the resolved config next to the snapshots.
    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    train_domain_adaptation(model, source_loader, target_loader, cfg)
コード例 #4
0
def main():
    """Run ``ranking_target_w_discrim`` over the target set with a trained net.

    Parses the command-line arguments, loads the config file into the global
    ``cfg``, always points ``cfg.TEST.SNAPSHOT_DIR`` at the experiment's
    snapshot directory, creates a ``compound_order`` output directory inside
    it, seeds the random number generators unless ``args.random_train`` is
    set, loads the iteration-46000 checkpoint into a DeepLabv2-VGG network,
    builds a shuffled ``BDDataSet`` loader, dumps the resolved config to
    ``train_cfg.yml`` and calls ``ranking_target_w_discrim``.

    Raises:
        AssertionError: If no config file was passed.
    """
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'

    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # The snapshot path is always regenerated here, whatever the config held.
    cfg.TEST.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
    os.makedirs(cfg.TEST.SNAPSHOT_DIR, exist_ok=True)

    device = cfg.GPU_ID

    # Output directory handed to the ranking routine.
    output_path = osp.join(cfg.TEST.SNAPSHOT_DIR, 'compound_order')
    os.makedirs(output_path, exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # INIT: seed every RNG (and each loader worker) unless random training
    # was requested.
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    # Dry-run mode stops once the config is resolved and the RNGs are seeded.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Hard-coded iteration of the checkpoint to restore.
    i_iter = 46000

    #### Load segmentation model #####
    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR, f'model_{i_iter}.pth')
    model_seg = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES, pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
    load_checkpoint_for_evaluation(model_seg, restore_from, device)

    #### Load Discriminator model (currently disabled) #####
    # restore_from_dis = osp.join(cfg.TEST.SNAPSHOT_DIR, f'model_{i_iter}_D2.pth')
    # model_dis = get_fc_discriminator(num_classes=cfg.NUM_CLASSES)
    # load_checkpoint_for_evaluation(model_dis, restore_from_dis, device)

    print('Model loaded')

    # The training split/info of the target set is used, cropped to the
    # test-time input size.
    target_dataset = BDDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=cfg.DATA_LIST_TARGET,
                               set=cfg.TRAIN.SET_TARGET,
                               info_path=cfg.TRAIN.INFO_TARGET,
                               crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                               mean=cfg.TRAIN.IMG_MEAN)

    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    # NOTE(review): the config is dumped under cfg.TRAIN.SNAPSHOT_DIR, which
    # this function never fills in (only cfg.TEST.SNAPSHOT_DIR is set above).
    # Confirm the train directory is the intended location.
    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # RANKING
    # ranking_target_w_discrim(model_seg, model_dis, target_loader, output_path, cfg)
    ranking_target_w_discrim(model_seg, target_loader, output_path, cfg)
コード例 #5
0
ファイル: train.py プロジェクト: gritYCDA/boundaryOCDA
def main():
    """Train a segmentation network with domain adaptation.

    Parses the command-line arguments, loads the config file into the global
    ``cfg`` and overrides worker count, loss weights, GAN type and the
    source/target dataset settings from the arguments. It then auto-fills the
    experiment name, snapshot directory and tensorboard directory, seeds the
    random number generators unless ``args.random_train`` is set, restores
    the segmentation network, builds the source and target loaders, dumps the
    resolved config to ``train_cfg.yml`` in the snapshot directory and calls
    ``train_domain_adaptation``.

    Raises:
        AssertionError: If no config file was passed.
        NotImplementedError: If the GAN method, source dataset, target
            dataset or ``cfg.TRAIN.MODEL`` is not supported.
    """
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    cfg.NUM_WORKERS = args.num_workers
    if args.option is not None:
        cfg.TRAIN.OPTION = args.option
    cfg.TRAIN.LAMBDA_BOUNDARY = args.LAMBDA_BOUNDARY
    cfg.TRAIN.LAMBDA_DICE = args.LAMBDA_DICE

    ## gan method settings: the adversarial loss weight depends on the GAN type
    cfg.GAN = args.gan
    if cfg.GAN == 'gan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.001  # GAN
    elif cfg.GAN == 'lsgan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.01  # LS-GAN
    else:
        raise NotImplementedError("Not Supported gan method")

    ### dataset settings
    cfg.SOURCE = args.source
    cfg.TARGET = args.target
    ## source config
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root /
                                   'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
        cfg.TRAIN.INPUT_SIZE_SOURCE = (1280, 720)

    elif cfg.SOURCE == 'SYNTHIA':
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    ## target config
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root /
                                   'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TRAIN.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TRAIN.INFO_TARGET = str(project_root /
                                    'advent/dataset/cityscapes_list/info.json')

        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root /
                                   'advent/dataset/cityscapes_list/info.json')

    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root /
                                   'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TRAIN.INPUT_SIZE_TARGET = (960, 540)
        cfg.TRAIN.INFO_TARGET = str(project_root /
                                    'advent/dataset/compound_list/info.json')

    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'

    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'
    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)
    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS,
                                                    'tensorboard',
                                                    cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        # An empty log dir is written back when tensorboard is not requested.
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''

    print('Using config:')
    pprint.pprint(cfg)

    # INIT: seed every RNG (and each loader worker) unless random training
    # was requested.
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    # Dry-run mode stops once the config is resolved and the RNGs are seeded.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Default target image list. Only the self-training VGG path below swaps
    # in the ordered list; defining it here keeps the 'DeepLabv2' path from
    # hitting an unbound ``trg_list`` when the BDD target dataset is built.
    trg_list = cfg.DATA_LIST_TARGET

    # LOAD SEGMENTATION NET
    if cfg.TRAIN.MODEL == 'DeepLabv2':
        model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TRAIN.MULTI_LEVEL)
        saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
        if 'DeepLab_resnet_pretrained_imagenet' in cfg.TRAIN.RESTORE_FROM:
            # For this checkpoint the first dotted component of every key is
            # dropped and 'layer5' entries are skipped, so those parameters
            # keep the values of the freshly built model.
            new_params = model.state_dict().copy()
            for key in saved_state_dict:
                key_parts = key.split('.')
                if key_parts[1] != 'layer5':
                    new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
            model.load_state_dict(new_params)
        else:
            model.load_state_dict(saved_state_dict)
    elif cfg.TRAIN.MODEL == 'DeepLabv2_VGG':
        model = get_deeplab_v2_vgg(cfg=cfg,
                                   num_classes=cfg.NUM_CLASSES,
                                   pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)

        if cfg.TRAIN.SELF_TRAINING:
            # Self-training resumes from a checkpoint under the snapshot root
            # and reads the target images from the ordered list.
            path = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.TRAIN.RESTORE_FROM_SELF)
            saved_state_dict = torch.load(path)
            model.load_state_dict(saved_state_dict, strict=False)
            trg_list = cfg.DATA_LIST_TARGET_ORDER
            print("self-training model loaded: {} ".format(path))
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")

    print("model: ")
    print(model)
    print('Model loaded')

    ########  DATALOADERS  ########
    # Original author's note: GTA5: 24,966: 274,626 / 24,966 = 11 epoch

    # self-training : target data shuffle
    shuffle = cfg.TRAIN.SHUFFLE
    if cfg.TRAIN.SELF_TRAINING:
        max_iteration = None
    else:
        max_iteration = cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE

    source_dataset = GTA5DataSet(root=cfg.DATA_DIRECTORY_SOURCE,
                                 list_path=cfg.DATA_LIST_SOURCE,
                                 set=cfg.TRAIN.SET_SOURCE,
                                 max_iters=max_iteration,
                                 crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
                                 mean=cfg.TRAIN.IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)
    if cfg.TARGET == "BDD":
        # Original author's note: GTA5: 14,697: 264,546 / 14,697 = 18 epoch
        # NOTE(review): max_iteration is derived from BATCH_SIZE_SOURCE here,
        # while the Cityscapes branch uses BATCH_SIZE_TARGET. Confirm this is
        # intended.
        target_dataset = BDDdataset(root=cfg.DATA_DIRECTORY_TARGET,
                                    list_path=trg_list,
                                    set=cfg.TRAIN.SET_TARGET,
                                    info_path=cfg.TRAIN.INFO_TARGET,
                                    max_iters=max_iteration,
                                    crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                    mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=shuffle,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    elif cfg.TARGET == 'Cityscapes':
        # The Cityscapes branch always reads cfg.DATA_LIST_TARGET and always
        # shuffles; it ignores the self-training list and shuffle flag.
        target_dataset = CityscapesDataSet(
            root=cfg.DATA_DIRECTORY_TARGET,
            list_path=cfg.DATA_LIST_TARGET,
            set=cfg.TRAIN.SET_TARGET,
            info_path=cfg.TRAIN.INFO_TARGET,
            max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
            crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
            mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=True,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} datasets")

    # Persist the resolved config next to the snapshots.
    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'),
              'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    train_domain_adaptation(model, source_loader, target_loader, cfg)