Exemple #1
0
def main(config_file, exp_suffix):
    """Evaluate the configured segmentation model(s) on the target test set.

    The config file is merged into the module-level ``cfg``. An empty
    ``cfg.EXP_NAME`` is replaced by ``<SOURCE>2<TARGET>_<MODEL>_<DA_METHOD>``,
    and an empty first test snapshot directory is derived from it and created.
    Every entry of ``cfg.TEST.MODEL`` must be ``'DeepLabv2'``.

    Args:
        config_file: Path of the config file to load; must not be ``None``.
        exp_suffix: When truthy, appended to the experiment name as
            ``_<exp_suffix>``.

    Raises:
        AssertionError: if ``config_file`` is ``None``, or ``cfg.TEST.MODE``
            is ``'best'`` while more than one model is configured.
        NotImplementedError: for any model name other than ``'DeepLabv2'``.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    # Experiment naming and snapshot location.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_'
                        f'{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}')
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'
    snapshot_dirs = cfg.TEST.SNAPSHOT_DIR  # mutated in place below
    if snapshot_dirs[0] == '':
        snapshot_dirs[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(snapshot_dirs[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    models = []
    for idx, model_name in enumerate(cfg.TEST.MODEL):
        if model_name != 'DeepLabv2':
            raise NotImplementedError(f"Not yet supported {model_name}")
        models.append(get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                     multi_level=cfg.TEST.MULTI_LEVEL[idx]))

    # Dry run: stop once the config has been resolved and models built.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Test set: options shared by both supported target datasets.
    shared = dict(root=cfg.DATA_DIRECTORY_TARGET,
                  set=cfg.TEST.SET_TARGET,
                  crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                  mean=cfg.TEST.IMG_MEAN,
                  labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    if cfg.TARGET == 'Mapillary':
        test_dataset = MapillaryDataSet(scale_label=False, **shared)
    else:
        test_dataset = CityscapesDataSet(list_path=cfg.DATA_LIST_TARGET,
                                         info_path=cfg.TEST.INFO_TARGET,
                                         num_classes=cfg.NUM_CLASSES,
                                         **shared)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)
    evaluate_domain_adaptation(models, test_loader, cfg)
Exemple #2
0
Fichier : esl.py Projet : CV-IP/ESL
def main(config_file, exp_suffix):
    """Extract pseudo-labels on the target split with a single DeepLabv2.

    The config file is merged into the module-level ``cfg``. An empty
    ``cfg.EXP_NAME`` is replaced by ``<SOURCE>2<TARGET>_<MODEL>_<DA_METHOD>``.
    All dataset/loader options come from the ``cfg.ESL`` section.

    Args:
        config_file: Path of the config file to load; must not be ``None``.
        exp_suffix: When truthy, appended to the experiment name as
            ``_<exp_suffix>``.

    Raises:
        AssertionError: if ``config_file`` is ``None``.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_'
                        f'{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}')
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'

    print('Using config:')
    pprint.pprint(cfg)

    model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                           multi_level=cfg.ESL.MULTI_LEVEL)

    # Dry run: stop once the config has been resolved and the model built.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    esl = cfg.ESL
    # Options shared by both supported target datasets.
    shared = dict(root=cfg.DATA_DIRECTORY_TARGET,
                  set=esl.SET_TARGET,
                  crop_size=esl.INPUT_SIZE_TARGET,
                  mean=esl.IMG_MEAN,
                  labels_size=esl.OUTPUT_SIZE_TARGET)
    if cfg.TARGET == 'Mapillary':
        test_dataset = MapillaryDataSet(**shared)
    else:
        test_dataset = CityscapesDataSet(list_path=cfg.DATA_LIST_TARGET,
                                         info_path=esl.INFO_TARGET,
                                         num_classes=cfg.NUM_CLASSES,
                                         **shared)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=esl.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)
    extract_pseudo_labels(model, test_loader, cfg)
Exemple #3
0
def main():
    """Train a DeepLabv2 with unsupervised domain adaptation (GTA5 -> Cityscapes).

    Parses the command line, merges the config file into the module-level
    ``cfg``, fills in the experiment name, snapshot and tensorboard
    directories when they are empty, optionally seeds every RNG, restores the
    initial weights, builds shuffled source/target loaders, writes the final
    config to ``<SNAPSHOT_DIR>/train_cfg.yml`` and starts
    ``train_domain_adaptation``.

    Raises:
        AssertionError: if no config file is given or the init checkpoint at
            ``cfg.TRAIN.RESTORE_FROM`` does not exist.
        NotImplementedError: if ``cfg.TRAIN.MODEL`` is not ``'DeepLabv2'``.
    """
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    # Experiment naming.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_'
                        f'{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}')
    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # Snapshot directory.
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        snapshot_dir = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        cfg.TRAIN.SNAPSHOT_DIR = snapshot_dir
        os.makedirs(snapshot_dir, exist_ok=True)

    # Tensorboard: an empty log dir means "disabled".
    if not args.tensorboard:
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''
    else:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS,
                                                    'tensorboard',
                                                    cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    print('Using config:')
    pprint.pprint(cfg)

    # Reproducibility: seed every RNG and give each loader worker its own seed.
    seed_worker = None
    if not args.random_train:
        seed = cfg.TRAIN.RANDOM_SEED
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        def seed_worker(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    # Dry run: stop once the config is resolved.
    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Segmentation network.
    restore_from = cfg.TRAIN.RESTORE_FROM
    assert osp.exists(restore_from), f'Missing init model {restore_from}'
    if cfg.TRAIN.MODEL != 'DeepLabv2':
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")
    model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                           multi_level=cfg.TRAIN.MULTI_LEVEL)
    saved_state_dict = torch.load(restore_from)
    if 'DeepLab_resnet_pretrained_imagenet' in restore_from:
        # ImageNet checkpoint: drop the leading key component and skip the
        # 'layer5' parameters, keeping the model's own values for those.
        new_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            parts = key.split('.')
            if parts[1] != 'layer5':
                new_params['.'.join(parts[1:])] = value
        model.load_state_dict(new_params)
    else:
        model.load_state_dict(saved_state_dict)
    print('Model loaded')

    # Data loaders (both shuffled, both sized to MAX_ITERS batches).
    train_cfg = cfg.TRAIN
    loader_opts = dict(num_workers=cfg.NUM_WORKERS,
                       shuffle=True,
                       pin_memory=True,
                       worker_init_fn=seed_worker)
    source_dataset = GTA5DataSet(
        root=cfg.DATA_DIRECTORY_SOURCE,
        list_path=cfg.DATA_LIST_SOURCE,
        set=train_cfg.SET_SOURCE,
        max_iters=train_cfg.MAX_ITERS * train_cfg.BATCH_SIZE_SOURCE,
        crop_size=train_cfg.INPUT_SIZE_SOURCE,
        mean=train_cfg.IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=train_cfg.BATCH_SIZE_SOURCE,
                                    **loader_opts)

    target_dataset = CityscapesDataSet(
        root=cfg.DATA_DIRECTORY_TARGET,
        list_path=cfg.DATA_LIST_TARGET,
        set=train_cfg.SET_TARGET,
        info_path=train_cfg.INFO_TARGET,
        max_iters=train_cfg.MAX_ITERS * train_cfg.BATCH_SIZE_TARGET,
        crop_size=train_cfg.INPUT_SIZE_TARGET,
        mean=train_cfg.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=train_cfg.BATCH_SIZE_TARGET,
                                    **loader_opts)

    # Persist the resolved config next to the snapshots.
    cfg_path = osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml')
    with open(cfg_path, 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    train_domain_adaptation(model, source_loader, target_loader, cfg)
Exemple #4
0
def main():
    """Evaluate model(s) trained with FDA/SSL pseudo-labels on the Cityscapes test split.

    Parses the command line, merges the config file into the module-level
    ``cfg`` and, when ``cfg.EXP_NAME`` is empty, builds a name encoding the FDA
    mode, the label-budget / MBT setting, the threshold and the SSL round.
    When ``args.MBT`` is falsy, ``args.LB`` is rewritten with '.' replaced by
    '_' before it is used.

    Raises:
        AssertionError: if no config file is given, or ``cfg.TEST.MODE`` is
            ``'best'`` with more than one model.
        NotImplementedError: for any model name other than ``'DeepLabv2'``.
    """
    args = get_arguments()
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    # Experiment naming.
    if cfg.EXP_NAME == '':
        base = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                f'{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}')
        if args.MBT:  # model trained on pseudo-labels produced by MBT
            cfg.EXP_NAME = (f'{base}_LB_MBT_THRESH_{args.thres}'
                            f'_ROUND_{args.round}')
        else:
            args.LB = str(args.LB).replace('.', '_')
            cfg.EXP_NAME = (f'{base}_LB_{args.LB}_THRESH_{args.thres}'
                            f'_ROUND_{args.round}')
    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # Snapshot location.
    snapshot_dirs = cfg.TEST.SNAPSHOT_DIR  # mutated in place below
    if snapshot_dirs[0] == '':
        snapshot_dirs[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(snapshot_dirs[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    models = []
    for idx, model_name in enumerate(cfg.TEST.MODEL):
        if model_name != 'DeepLabv2':
            raise NotImplementedError(f"Not yet supported {model_name}")
        models.append(get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                     multi_level=cfg.TEST.MULTI_LEVEL[idx]))

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # NOTE(review): the list path is hard-coded relative to the working
    # directory instead of using cfg.DATA_LIST_TARGET -- confirm intended.
    test_dataset = CityscapesDataSet(
        args=args,
        root=cfg.DATA_DIRECTORY_TARGET,
        list_path='../ADVENT/advent/dataset/cityscapes_list/{}.txt',
        set=cfg.TEST.SET_TARGET,
        info_path=cfg.TEST.INFO_TARGET,
        crop_size=cfg.TEST.INPUT_SIZE_TARGET,
        mean=cfg.TEST.IMG_MEAN,
        labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)
    evaluate_domain_adaptation(models, test_loader, cfg)
Exemple #5
0
def main(args):
    """Split the target set by prediction entropy and pseudo-label its easy part.

    Step 1 runs the restored generator over every image of the target training
    split and records ``(name, mean entropy * normalizor)`` per image;
    ``cluster_subdomain`` turns that list into the easy split.
    Step 2 re-runs the generator on the easy images only, stores the per-pixel
    arg-max label and max softmax probability, derives one confidence
    threshold per class (the value at the 66% position of that class's sorted
    probabilities, capped at 0.9) and passes everything to
    ``colorize_save_with_thresholding``.

    Args:
        args: Parsed command-line namespace. Reads ``cfg``, ``FDA_mode``
            (``'on'``/``'off'``), ``round``, ``MBT``, ``best_iter`` and
            ``normalize``; overwrites ``args.LB`` with ``str(args.MBT)``.

    Raises:
        AssertionError: if ``args.cfg`` is ``None`` or not every easy-split
            image was processed in step 2.
        KeyError: if ``args.FDA_mode`` is neither ``'on'`` nor ``'off'``, or
            ``args.round`` is negative.

    Notes:
        Depends on the module-level ``cfg``, ``save_dir`` and ``thresholding``.
        Per-image bookkeeping uses ``name[0]`` and ``[0]`` indexing, so this
        presumably expects ``cfg.TRAIN.BATCH_SIZE_TARGET == 1`` -- TODO confirm.
    """
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    # Read GPU_ID only after the config file has been merged; reading it
    # before cfg_from_file silently ignored a GPU_ID set in the config file.
    device = cfg.GPU_ID
    if args.FDA_mode not in ('on', 'off'):
        raise KeyError(f'Unknown FDA_mode: {args.FDA_mode!r}')

    out_dir = save_dir % (args.FDA_mode, args.round)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    args.LB = str(args.MBT)  # original intent: set args.LB = 'MBT'
    # Source mean as a (1, 3, 1, 1) tensor so it broadcasts over (B, 3, H, W).
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    # Choose which previously trained model to restore.
    if args.round == 0:  # first round of SSL
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                        f'{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}')
    elif args.round > 0:
        # Later rounds adapt the easy split (source) to the hard split (target).
        cfg.SOURCE = 'CityscapesEasy'
        cfg.TARGET = 'CityscapesHard'
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                        f'{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}'
                        f'_THRESH_{str(thresholding)}_ROUND_{args.round - 1}')
    else:
        raise KeyError(f'Invalid SSL round: {args.round}')

    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # Model with parameters trained from inter-domain adaptation.
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)
    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')
    print("Loading the generator:", restore_from)
    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    target_dataset = CityscapesDataSet(args=args,
                                       root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    # Upsample predictions back to the input resolution; INPUT_SIZE_TARGET is
    # (width, height) while nn.Upsample expects (height, width).
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    out_h, out_w = input_size_target[1], input_size_target[0]
    interp_target = nn.Upsample(size=(out_h, out_w),
                                mode='bilinear',
                                align_corners=True)

    def _normalize(image):
        # With FDA off, the dataset class has already normalized the image;
        # with FDA on, the source mean still has to be subtracted here.
        return image - SRC_IMG_MEAN if args.FDA_mode == 'on' else image

    # Step 1: entropy ranking -- split the target set into easy and hard cases.
    entropy_list = []
    for image, _, _, name in tqdm(target_loader):
        image = _normalize(image)
        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize == True:
                normalizor = (11 - len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main, dim=1))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))

    _, easy_split = cluster_subdomain(entropy_list, args, thresholding)

    # Step 2: predict labels/confidences for the easy split only.
    predicted_label = np.zeros((len(easy_split), out_h, out_w))
    predicted_prob = np.zeros((len(easy_split), out_h, out_w))
    image_name = []
    idx = 0
    for image, _, _, name in tqdm(target_loader):
        if name[0] not in easy_split:
            continue
        image = _normalize(image)
        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            # Softmax over classes for the first image in the batch -> (C, H, W).
            probs = F.softmax(interp_target(pred_trg_main), dim=1)[0].cpu().numpy()
            probs = probs.transpose(1, 2, 0)  # -> (H, W, C)
            predicted_label[idx] = np.argmax(probs, axis=2)
            predicted_prob[idx] = np.max(probs, axis=2)
            image_name.append(name[0])
            idx += 1

    # Every image of the easy split must have been processed.
    assert len(easy_split) == len(image_name)

    # Per-class threshold: value at the 66% position of the sorted confidences.
    thres = []
    for i in range(cfg.NUM_CLASSES):
        x = predicted_prob[predicted_label == i]
        if len(x) == 0:
            thres.append(0)
            continue
        x = np.sort(x)
        # Clamp the index: for a single-pixel class round(1 * 0.66) == 1 would
        # index one past the end. (np.int was removed in NumPy 1.24.)
        k = min(int(np.round(len(x) * 0.66)), len(x) - 1)
        thres.append(x[k])
    print(thres)
    thres = np.array(thres)
    thres[thres > 0.9] = 0.9
    print(thres)
    colorize_save_with_thresholding(easy_split, thres, predicted_label,
                                    predicted_prob, image_name, args)
Exemple #6
0
def main(args):
    """Rank target images by prediction entropy and split them into sub-domains.

    Restores the generator of the requested SSL round, runs it over every
    image of the target training split, saves a colorized prediction per
    image via ``colorize_save`` and records ``(name, mean entropy *
    normalizor)``; the list is finally passed to ``cluster_subdomain``.

    Args:
        args: Parsed command-line namespace. Reads ``cfg``, ``FDA_mode``
            (``'on'``/``'off'``), ``round``, ``LB``, ``best_iter`` and
            ``normalize``; ``args.LB`` is rewritten with '.' replaced by '_'.

    Raises:
        AssertionError: if ``args.cfg`` is ``None``.
        KeyError: if ``args.FDA_mode`` is neither ``'on'`` nor ``'off'``, or
            ``args.round`` is negative.

    Notes:
        Depends on the module-level ``cfg``, ``save_dir`` and ``thresholding``.
        ``name[0]`` is used as the image key, so this presumably expects
        ``cfg.TRAIN.BATCH_SIZE_TARGET == 1`` -- TODO confirm.
    """
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    # Read GPU_ID only after the config file has been merged; reading it
    # before cfg_from_file silently ignored a GPU_ID set in the config file.
    device = cfg.GPU_ID
    if args.FDA_mode not in ('on', 'off'):
        raise KeyError(f'Unknown FDA_mode: {args.FDA_mode!r}')

    out_dir = save_dir % (args.FDA_mode)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    args.LB = str(args.LB).replace('.', '_')
    # Source mean as a (1, 3, 1, 1) tensor so it broadcasts over (B, 3, H, W).
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    if args.round == 0:  # first round of SSL
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                        f'{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}')
    elif args.round > 0:
        # Later rounds adapt the easy split (source) to the hard split (target).
        cfg.SOURCE = 'CityscapesEasy'
        cfg.TARGET = 'CityscapesHard'
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                        f'{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}'
                        f'_THRESH_{str(thresholding)}_ROUND_{args.round - 1}')
    else:
        raise KeyError(f'Invalid SSL round: {args.round}')

    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # Model with parameters trained from inter-domain adaptation.
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)
    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')
    print("Loading the generator:", restore_from)
    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    target_dataset = CityscapesDataSet(args=args,
                                       root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    # INPUT_SIZE_TARGET is (width, height); nn.Upsample expects (height, width).
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    entropy_list = []
    for image, _, _, name in tqdm(target_loader):
        # With FDA off, the dataset class has already normalized the image;
        # with FDA on, the source mean still has to be subtracted here.
        if args.FDA_mode == 'on':
            image = image - SRC_IMG_MEAN

        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize == True:
                normalizor = (11 - len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main, dim=1))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            colorize_save(pred_trg_main, name[0], args)

    # Split the entropy list into sub-domains.
    cluster_subdomain(entropy_list, args, thresholding)
Exemple #7
0
def main(args):
    """Rank target images by prediction entropy and split them into sub-domains.

    Restores ``model_<best_iter>.pth`` of the ``<SOURCE>2<TARGET>_<MODEL>_<DA_METHOD>``
    experiment, runs it over every image of the target training split, saves
    a colorized prediction per image via ``colorize_save`` and records
    ``(name, mean entropy * normalizor)``. The list is passed to
    ``cluster_subdomain`` together with ``args.lambda1``.

    Args:
        args: Parsed command-line namespace. Reads ``cfg``, ``best_iter``,
            ``normalize`` and ``lambda1``.

    Raises:
        AssertionError: if ``args.cfg`` is ``None``.

    Notes:
        ``name[0]`` is used as the image key, so this presumably expects
        ``cfg.TRAIN.BATCH_SIZE_TARGET == 1`` -- TODO confirm.
    """
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    # Read GPU_ID only after the config file has been merged; reading it
    # before cfg_from_file silently ignored a GPU_ID set in the config file.
    device = cfg.GPU_ID

    if not os.path.exists('./color_masks'):
        os.mkdir('./color_masks')

    cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # Model with parameters trained from inter-domain adaptation.
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)
    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')
    print("Loading the generator:", restore_from)
    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    target_dataset = CityscapesDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    # INPUT_SIZE_TARGET is (width, height); nn.Upsample expects (height, width).
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    entropy_list = []
    for image, _, _, name in tqdm(target_loader):
        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize == True:
                normalizor = (11 - len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main, dim=1))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            colorize_save(pred_trg_main, name[0])

    # Split the entropy list into sub-domains.
    cluster_subdomain(entropy_list, args.lambda1)
Exemple #8
0
def main(config_file, exp_suffix):
    """Evaluate model(s) trained on GTA5 against Cityscapes or BDD (compound).

    Besides the config file, this reads ``args.num_workers``, ``args.source``
    and ``args.target`` from a module-level ``args``. That name is not a
    parameter, so presumably the script entry point defines it; verify.
    Dataset paths, experiment roots and test input/output sizes are then
    hard-wired per target before the model(s) are built and evaluated.

    Args:
        config_file: Path of the config file to load; must not be ``None``.
        exp_suffix: When truthy, appended to the experiment name as
            ``_<exp_suffix>``.

    Raises:
        AssertionError: if ``config_file`` is ``None``, or ``cfg.TEST.MODE`` is
            ``'best'`` with more than one model.
        NotImplementedError: unsupported source/target dataset or model name.
    """
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    cfg.NUM_WORKERS = args.num_workers

    # Dataset settings.
    cfg.SOURCE = args.source
    cfg.TARGET = args.target

    # Source: only GTA5 is wired up (SYNTHIA is not supported yet).
    if cfg.SOURCE != 'GTA':
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")
    cfg.DATA_LIST_SOURCE = str(project_root / 'advent/dataset/gta5_list/{}.txt')
    cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')

    # Target: paths, experiment roots and test sizes (width, height).
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')
    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TEST.INPUT_SIZE_TARGET = (960, 540)
        cfg.TEST.OUTPUT_SIZE_TARGET = (1280, 720)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/compound_list/info.json')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # Auto-generate the experiment name / snapshot path if not specified.
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = (f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_'
                        f'{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}')
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TEST.SNAPSHOT_DIR[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # Build one network per configured test model.
    models = []
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    for i in range(n_models):
        model_name = cfg.TEST.MODEL[i]
        if model_name == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        elif model_name == 'DeepLabv2_VGG':
            # Dispatch on the per-test-model name like the branch above; this
            # previously tested cfg.TRAIN.MODEL, so a 'DeepLabv2_VGG' test model
            # was rejected unless the training model happened to match.
            # NOTE(review): `cfg.TRAIN_VGG_PRE_MODEL` is a top-level key --
            # confirm it is not meant to be `cfg.TRAIN.VGG_PRE_MODEL`.
            model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {model_name}")
        models.append(model)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Both target datasets take the same constructor arguments.
    dataset_kwargs = dict(root=cfg.DATA_DIRECTORY_TARGET,
                          list_path=cfg.DATA_LIST_TARGET,
                          set=cfg.TEST.SET_TARGET,
                          info_path=cfg.TEST.INFO_TARGET,
                          crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                          mean=cfg.TEST.IMG_MEAN,
                          labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    if cfg.TARGET == 'Cityscapes':
        test_dataset = CityscapesDataSet(**dataset_kwargs)
    elif cfg.TARGET == 'BDD':
        test_dataset = BDDdataset(**dataset_kwargs)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} datasets")
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)
    evaluate_domain_adaptation(models, test_loader, cfg)
Exemple #9
0
def main():
    """Train a segmentation network with unsupervised domain adaptation.

    Parses the command-line arguments, loads ``args.cfg`` into the global
    ``cfg`` and overrides it with CLI values (workers, loss weights, GAN
    flavour, source/target datasets). It then fills in dataset-specific paths
    and input sizes, derives the experiment name and the snapshot/tensorboard
    directories, and seeds the RNGs unless ``args.random_train`` is set.

    It then builds the segmentation network and the source/target data
    loaders, dumps the final config to ``<TRAIN.SNAPSHOT_DIR>/train_cfg.yml``,
    and calls ``train_domain_adaptation``. With the environment variable
    ``ADVENT_DRY_RUN=1`` it returns right after the configuration step.

    Raises:
        NotImplementedError: unsupported GAN method, source/target dataset
            or model architecture.
    """
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    cfg.NUM_WORKERS = args.num_workers
    if args.option is not None:
        cfg.TRAIN.OPTION = args.option
    cfg.TRAIN.LAMBDA_BOUNDARY = args.LAMBDA_BOUNDARY
    cfg.TRAIN.LAMBDA_DICE = args.LAMBDA_DICE

    ## gan method settings: the adversarial loss weight depends on the flavour
    cfg.GAN = args.gan
    if cfg.GAN == 'gan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.001  # GAN
    elif cfg.GAN == 'lsgan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.01  # LS-GAN
    else:
        raise NotImplementedError(f"Not Supported gan method: {cfg.GAN}")

    ### dataset settings
    cfg.SOURCE = args.source
    cfg.TARGET = args.target
    ## source config -- only GTA5 is wired up here (SYNTHIA is not supported yet)
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root /
                                   'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
        cfg.TRAIN.INPUT_SIZE_SOURCE = (1280, 720)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    ## target config
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root /
                                   'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TRAIN.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TRAIN.INFO_TARGET = str(project_root /
                                    'advent/dataset/cityscapes_list/info.json')

        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root /
                                   'advent/dataset/cityscapes_list/info.json')

    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root /
                                   'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TRAIN.INPUT_SIZE_TARGET = (960, 540)
        cfg.TRAIN.INFO_TARGET = str(project_root /
                                    'advent/dataset/compound_list/info.json')

    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'

    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'
    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)
    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS,
                                                    'tensorboard',
                                                    cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        # empty logdir disables tensorboard logging downstream
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''

    print('Using config:')
    pprint.pprint(cfg)

    # INIT: deterministic seeding unless random training was requested
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            # give every DataLoader worker a distinct but reproducible seed
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # Target image list. Defined up front so every model branch leaves it
    # bound; only VGG self-training swaps in the ordered list.
    trg_list = cfg.DATA_LIST_TARGET

    # LOAD SEGMENTATION NET
    if cfg.TRAIN.MODEL == 'DeepLabv2':
        model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TRAIN.MULTI_LEVEL)
        saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
        if 'DeepLab_resnet_pretrained_imagenet' in cfg.TRAIN.RESTORE_FROM:
            # ImageNet checkpoint: strip the first key component and skip the
            # 'layer5' weights (presumably the classifier head, whose shape
            # depends on the number of classes -- confirm).
            new_params = model.state_dict().copy()
            for key, value in saved_state_dict.items():
                key_parts = key.split('.')
                if key_parts[1] != 'layer5':
                    new_params['.'.join(key_parts[1:])] = value
            model.load_state_dict(new_params)
        else:
            model.load_state_dict(saved_state_dict)
    elif cfg.TRAIN.MODEL == 'DeepLabv2_VGG':
        # NOTE(review): top-level cfg key, unlike the cfg.TRAIN.* keys used
        # elsewhere -- confirm it is not meant to be cfg.TRAIN.VGG_PRE_MODEL.
        model = get_deeplab_v2_vgg(cfg=cfg,
                                   num_classes=cfg.NUM_CLASSES,
                                   pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)

        if cfg.TRAIN.SELF_TRAINING:
            path = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.TRAIN.RESTORE_FROM_SELF)
            saved_state_dict = torch.load(path)
            # strict=False: tolerate keys missing from / extra in the snapshot
            model.load_state_dict(saved_state_dict, strict=False)
            trg_list = cfg.DATA_LIST_TARGET_ORDER
            print("self-training model loaded: {} ".format(path))
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")

    print("model: ")
    print(model)
    print('Model loaded')

    ########  DATALOADERS  ########
    # GTA5: 24,966: 274,626 / 24,966 = 11 epoch

    # self-training: iterate the target list once (no max_iters), with the
    # configured shuffle setting
    shuffle = cfg.TRAIN.SHUFFLE
    if cfg.TRAIN.SELF_TRAINING:
        max_iteration = None
    else:
        max_iteration = cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE

    source_dataset = GTA5DataSet(root=cfg.DATA_DIRECTORY_SOURCE,
                                 list_path=cfg.DATA_LIST_SOURCE,
                                 set=cfg.TRAIN.SET_SOURCE,
                                 max_iters=max_iteration,
                                 crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
                                 mean=cfg.TRAIN.IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)
    if cfg.TARGET == "BDD":
        # GTA5: 14,697: 264,546 / 14,697 = 18 epoch
        # NOTE(review): max_iteration is derived from BATCH_SIZE_SOURCE; this
        # only matches the target loader when both batch sizes are equal.
        target_dataset = BDDdataset(root=cfg.DATA_DIRECTORY_TARGET,
                                    list_path=trg_list,
                                    set=cfg.TRAIN.SET_TARGET,
                                    info_path=cfg.TRAIN.INFO_TARGET,
                                    max_iters=max_iteration,
                                    crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                    mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=shuffle,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    elif cfg.TARGET == 'Cityscapes':
        # NOTE(review): unlike the BDD branch, this ignores trg_list,
        # TRAIN.SELF_TRAINING and TRAIN.SHUFFLE -- confirm that is intended.
        target_dataset = CityscapesDataSet(
            root=cfg.DATA_DIRECTORY_TARGET,
            list_path=cfg.DATA_LIST_TARGET,
            set=cfg.TRAIN.SET_TARGET,
            info_path=cfg.TRAIN.INFO_TARGET,
            max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
            crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
            mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=True,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} datasets")

    # persist the effective config next to the snapshots for reproducibility
    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'),
              'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    train_domain_adaptation(model, source_loader, target_loader, cfg)
Exemple #10
0
def main():
    """Train a (optionally depth-aware) segmentation model with domain adaptation.

    Loads ``args.cfg`` into the global ``cfg``, derives the experiment name and
    the snapshot/tensorboard directories, and seeds the RNGs unless
    ``args.random_train`` is set. With the environment variable
    ``DADA_DRY_RUN=1`` it returns right after this configuration step.

    Otherwise it builds the network from ``cfg.TRAIN.RESTORE_FROM``, creates
    the SYNTHIA (source) and Cityscapes/Mapillary (target) loaders, and writes
    the effective config to ``<TRAIN.SNAPSHOT_DIR>/train_cfg.yml``. It then
    trains with ``train_domain_adaptation_with_depth`` when ``cfg.USE_DEPTH``
    is set, and with ``train_domain_adaptation`` otherwise.

    Raises:
        NotImplementedError: unsupported model architecture or target dataset.
    """
    # LOAD ARGS
    args = get_arguments()
    print("Called with args:")
    print(args)

    assert args.cfg is not None, "Missing cfg file"
    cfg_from_file(args.cfg)
    # auto-generate exp name if not specified
    if cfg.EXP_NAME == "":
        cfg.EXP_NAME = f"{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}"

    if args.exp_suffix:
        cfg.EXP_NAME += f"_{args.exp_suffix}"
    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == "":
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
    os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)
    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == "":
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(
                cfg.EXP_ROOT_LOGS, "tensorboard", cfg.EXP_NAME
            )
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        # empty logdir disables tensorboard logging downstream
        cfg.TRAIN.TENSORBOARD_LOGDIR = ""
    print("Using config:")
    pprint.pprint(cfg)

    # INIT: deterministic seeding unless random training was requested
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            # give every DataLoader worker a distinct but reproducible seed
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    if os.environ.get("DADA_DRY_RUN", "0") == "1":
        return

    # LOAD SEGMENTATION NET
    assert osp.exists(
        cfg.TRAIN.RESTORE_FROM
    ), f"Missing init model {cfg.TRAIN.RESTORE_FROM}"
    # Both supported architectures take the same constructor arguments and
    # share the checkpoint-loading logic; only the factory differs.
    if cfg.TRAIN.MODEL == "DeepLabv2_depth":
        build_model = get_deeplab_v2_depth
    elif cfg.TRAIN.MODEL == "DeepLabv2":
        build_model = get_deeplab_v2
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")
    model = build_model(
        num_classes=cfg.NUM_CLASSES,
        multi_level=cfg.TRAIN.MULTI_LEVEL
    )
    saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
    if "DeepLab_resnet_pretrained_imagenet" in cfg.TRAIN.RESTORE_FROM:
        # ImageNet checkpoint: strip the first key component and skip the
        # "layer5" weights (presumably the classifier head, whose shape
        # depends on the number of classes -- confirm).
        new_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            key_parts = key.split(".")
            if key_parts[1] != "layer5":
                new_params[".".join(key_parts[1:])] = value
        model.load_state_dict(new_params)
    else:
        model.load_state_dict(saved_state_dict)
    print("Model loaded")

    # DATALOADERS
    source_dataset = SYNTHIADataSetDepth(
        root=cfg.DATA_DIRECTORY_SOURCE,
        list_path=cfg.DATA_LIST_SOURCE,
        set=cfg.TRAIN.SET_SOURCE,
        num_classes=cfg.NUM_CLASSES,
        max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE,
        crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
        mean=cfg.TRAIN.IMG_MEAN,
        use_depth=cfg.USE_DEPTH,
    )
    source_loader = data.DataLoader(
        source_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
        num_workers=cfg.NUM_WORKERS,
        shuffle=True,
        pin_memory=True,
        worker_init_fn=_init_fn,
    )

    if cfg.TARGET == 'Cityscapes':
        target_dataset = CityscapesDataSet(
            root=cfg.DATA_DIRECTORY_TARGET,
            list_path=cfg.DATA_LIST_TARGET,
            set=cfg.TRAIN.SET_TARGET,
            info_path=cfg.TRAIN.INFO_TARGET,
            max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
            crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
            mean=cfg.TRAIN.IMG_MEAN
        )
    elif cfg.TARGET == 'Mapillary':
        target_dataset = MapillaryDataSet(
            root=cfg.DATA_DIRECTORY_TARGET,
            list_path=cfg.DATA_LIST_TARGET,
            set=cfg.TRAIN.SET_TARGET,
            info_path=cfg.TRAIN.INFO_TARGET,
            max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
            crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
            mean=cfg.TRAIN.IMG_MEAN,
            scale_label=True
        )
    else:
        raise NotImplementedError(f"Not yet supported dataset {cfg.TARGET}")
    target_loader = data.DataLoader(
        target_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
        num_workers=cfg.NUM_WORKERS,
        shuffle=True,
        pin_memory=True,
        worker_init_fn=_init_fn,
    )

    # persist the effective config next to the snapshots for reproducibility
    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, "train_cfg.yml"), "w") as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    if cfg.USE_DEPTH:
        train_domain_adaptation_with_depth(model, source_loader, target_loader, cfg)
    else:
        train_domain_adaptation(model, source_loader, target_loader, cfg)