import os
import os.path as osp
import pprint
import random

import numpy as np
import torch
import yaml
from torch.utils import data

# Module paths below follow the upstream ADVENT layout; the fork-specific names
# (get_deeplab_v2_vgg, BDDataSet / BDDdataset, ranking_target_w_discrim,
# project_root, get_arguments) are assumed to come from this repo's own modules.
from advent.domain_adaptation.config import cfg, cfg_from_file
from advent.domain_adaptation.eval_UDA import evaluate_domain_adaptation
from advent.domain_adaptation.train_UDA import train_domain_adaptation
from advent.dataset.cityscapes import CityscapesDataSet
from advent.dataset.gta5 import GTA5DataSet
from advent.model.deeplabv2 import get_deeplab_v2


def main(config_file, exp_suffix):
    # LOAD ARGS
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'

    # auto-generate snapshot path if not specified
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TEST.SNAPSHOT_DIR[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # load models
    models = []
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    for i in range(n_models):
        if cfg.TEST.MODEL[i] == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        elif cfg.TEST.MODEL[i] == 'DeepLabv2_VGG':
            model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(model)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # dataloaders
    test_dataset = BDDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                             list_path=cfg.DATA_LIST_TARGET,
                             set=cfg.TEST.SET_TARGET,
                             info_path=cfg.TEST.INFO_TARGET,
                             crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                             mean=cfg.TEST.IMG_MEAN,
                             labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                  num_workers=cfg.NUM_WORKERS,
                                  shuffle=False,
                                  pin_memory=True)

    # eval
    evaluate_domain_adaptation(models, test_loader, cfg)
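# Hedged usage sketch (not in the original excerpt): how this evaluation entry
# point is presumably invoked. get_arguments() is the CLI parser the training
# scripts below use; the --cfg and --exp-suffix flag names are assumptions.
if __name__ == '__main__':
    args = get_arguments()
    print('Called with args:')
    print(args)
    main(args.cfg, args.exp_suffix)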
def main(config_file, exp_suffix):
    # LOAD ARGS
    # parse CLI flags (num_workers, source, target), as the training scripts do
    args = get_arguments()
    assert config_file is not None, 'Missing cfg file'
    cfg_from_file(config_file)
    cfg.NUM_WORKERS = args.num_workers

    ### dataset settings
    cfg.SOURCE = args.source
    cfg.TARGET = args.target

    ## source config
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root / 'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
    elif cfg.SOURCE == 'SYNTHIA':
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    ## target config
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')
    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TEST.INPUT_SIZE_TARGET = (960, 540)
        cfg.TEST.OUTPUT_SIZE_TARGET = (1280, 720)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/compound_list/info.json')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'
    if exp_suffix:
        cfg.EXP_NAME += f'_{exp_suffix}'

    # auto-generate snapshot path if not specified
    if cfg.TEST.SNAPSHOT_DIR[0] == '':
        cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TEST.SNAPSHOT_DIR[0], exist_ok=True)

    print('Using config:')
    pprint.pprint(cfg)

    # load models
    models = []
    n_models = len(cfg.TEST.MODEL)
    if cfg.TEST.MODE == 'best':
        assert n_models == 1, 'Not yet supported'
    for i in range(n_models):
        if cfg.TEST.MODEL[i] == 'DeepLabv2':
            model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                                   multi_level=cfg.TEST.MULTI_LEVEL[i])
        elif cfg.TEST.MODEL[i] == 'DeepLabv2_VGG':
            model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                       pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        else:
            raise NotImplementedError(f"Not yet supported {cfg.TEST.MODEL[i]}")
        models.append(model)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # dataloaders
    if cfg.TARGET == 'Cityscapes':
        test_dataset = CityscapesDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                                         list_path=cfg.DATA_LIST_TARGET,
                                         set=cfg.TEST.SET_TARGET,
                                         info_path=cfg.TEST.INFO_TARGET,
                                         crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                                         mean=cfg.TEST.IMG_MEAN,
                                         labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                      num_workers=cfg.NUM_WORKERS,
                                      shuffle=False,
                                      pin_memory=True)
    elif cfg.TARGET == 'BDD':
        test_dataset = BDDdataset(root=cfg.DATA_DIRECTORY_TARGET,
                                  list_path=cfg.DATA_LIST_TARGET,
                                  set=cfg.TEST.SET_TARGET,
                                  info_path=cfg.TEST.INFO_TARGET,
                                  crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                                  mean=cfg.TEST.IMG_MEAN,
                                  labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=cfg.TEST.BATCH_SIZE_TARGET,
                                      num_workers=cfg.NUM_WORKERS,
                                      shuffle=False,
                                      pin_memory=True)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # eval
    evaluate_domain_adaptation(models, test_loader, cfg)
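# Note (assumption about the dataset classes): the '{}' placeholder in
# DATA_LIST_SOURCE / DATA_LIST_TARGET is a template the dataset presumably
# fills with the split name, e.g.:
#     list_file = cfg.DATA_LIST_TARGET.format(cfg.TEST.SET_TARGET)
#     # e.g. 'advent/dataset/compound_list/val.txt' when SET_TARGET is 'val'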
def main():
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)

    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS, 'tensorboard', cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''

    print('Using config:')
    pprint.pprint(cfg)

    shuffle = cfg.TRAIN.SHUFFLE

    # INIT
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # LOAD SEGMENTATION NET
    # assert osp.exists(cfg.TRAIN.RESTORE_FROM), f'Missing init model {cfg.TRAIN.RESTORE_FROM}'
    trg_list = cfg.DATA_LIST_TARGET  # default target list; self-training overrides it below
    if cfg.TRAIN.MODEL == 'DeepLabv2':
        model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES, multi_level=cfg.TRAIN.MULTI_LEVEL)
        saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
        if 'DeepLab_resnet_pretrained_imagenet' in cfg.TRAIN.RESTORE_FROM:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)
        else:
            model.load_state_dict(saved_state_dict)
    elif cfg.TRAIN.MODEL == 'DeepLabv2_VGG':
        model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                   pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        if cfg.TRAIN.SELF_TRAINING:
            path = osp.join(cfg.TRAIN.RESTORE_FROM_SELF, 'model_46000.pth')
            saved_state_dict = torch.load(path)
            model.load_state_dict(saved_state_dict, strict=False)
            trg_list = cfg.DATA_LIST_TARGET_ORDER
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")
    print('Model loaded')

    # DATALOADERS
    source_dataset = GTA5DataSet(root=cfg.DATA_DIRECTORY_SOURCE,
                                 list_path=cfg.DATA_LIST_SOURCE,
                                 set=cfg.TRAIN.SET_SOURCE,
                                 # max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE,
                                 max_iters=None,
                                 crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
                                 mean=cfg.TRAIN.IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    target_dataset = BDDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=trg_list,
                               set=cfg.TRAIN.SET_TARGET,
                               info_path=cfg.TRAIN.INFO_TARGET,
                               # max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
                               max_iters=None,
                               crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                               mean=cfg.TRAIN.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=shuffle,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    train_domain_adaptation(model, source_loader, target_loader, cfg)
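# Assumed script entry point (not shown in this excerpt); the argument parsing
# happens inside main() via get_arguments(), so no parameters are passed here.
if __name__ == '__main__':
    main()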
def main():
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # auto-generate snapshot path if not specified
    # if cfg.TEST.SNAPSHOT_DIR == '':
    cfg.TEST.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
    os.makedirs(cfg.TEST.SNAPSHOT_DIR, exist_ok=True)

    num_classes = cfg.NUM_CLASSES
    device = cfg.GPU_ID
    output_path = osp.join(cfg.TEST.SNAPSHOT_DIR, 'compound_order')
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    print('Using config:')
    pprint.pprint(cfg)

    # INIT
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    i_iter = 46000

    #### Load Segmentation model #####
    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR, f'model_{i_iter}.pth')
    model_seg = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                   pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
    load_checkpoint_for_evaluation(model_seg, restore_from, device)

    #### Load Discriminator model #####
    # restore_from_dis = osp.join(cfg.TEST.SNAPSHOT_DIR, f'model_{i_iter}_D2.pth')
    # model_dis = get_fc_discriminator(num_classes=num_classes)
    # load_checkpoint_for_evaluation(model_dis, restore_from_dis, device)
    print('Model loaded')

    target_dataset = BDDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                               list_path=cfg.DATA_LIST_TARGET,
                               set=cfg.TRAIN.SET_TARGET,
                               info_path=cfg.TRAIN.INFO_TARGET,
                               crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                               mean=cfg.TRAIN.IMG_MEAN)
    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    with open(osp.join(cfg.TEST.SNAPSHOT_DIR, 'train_cfg.yml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # TARGET RANKING
    # ranking_target_w_discrim(model_seg, model_dis, target_loader, output_path, cfg)
    ranking_target_w_discrim(model_seg, target_loader, output_path, cfg)
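# Hedged sketch (assumption): load_checkpoint_for_evaluation is used above but
# defined elsewhere; in upstream ADVENT it amounts to restoring the snapshot and
# switching the model to inference mode on the configured GPU:
def load_checkpoint_for_evaluation(model, checkpoint, device):
    saved_state_dict = torch.load(checkpoint)  # read the snapshot from disk
    model.load_state_dict(saved_state_dict)    # restore parameters
    model.eval()                               # disable dropout / BN updates
    model.cuda(device)                         # move to the test GPU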
def main():
    # LOAD ARGS
    args = get_arguments()
    print('Called with args:')
    print(args)

    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)
    cfg.NUM_WORKERS = args.num_workers
    if args.option is not None:
        cfg.TRAIN.OPTION = args.option
    cfg.TRAIN.LAMBDA_BOUNDARY = args.LAMBDA_BOUNDARY
    cfg.TRAIN.LAMBDA_DICE = args.LAMBDA_DICE

    ## gan method settings
    cfg.GAN = args.gan
    if cfg.GAN == 'gan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.001  # GAN
    elif cfg.GAN == 'lsgan':
        cfg.TRAIN.LAMBDA_ADV_MAIN = 0.01  # LS-GAN
    else:
        raise NotImplementedError(f"Not yet supported gan method {cfg.GAN}")

    ### dataset settings
    cfg.SOURCE = args.source
    cfg.TARGET = args.target

    ## source config
    if cfg.SOURCE == 'GTA':
        cfg.DATA_LIST_SOURCE = str(project_root / 'advent/dataset/gta5_list/{}.txt')
        cfg.DATA_DIRECTORY_SOURCE = str(project_root / 'data/GTA5')
        cfg.TRAIN.INPUT_SIZE_SOURCE = (1280, 720)
    elif cfg.SOURCE == 'SYNTHIA':
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")
    else:
        raise NotImplementedError(f"Not yet supported {cfg.SOURCE} dataset")

    ## target config
    if cfg.TARGET == 'Cityscapes':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/cityscapes_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/cityscapes')
        cfg.EXP_ROOT = project_root / 'experiments_G2C'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots_G2C')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs_G2C')
        cfg.TRAIN.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TRAIN.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')
        cfg.TEST.INPUT_SIZE_TARGET = (1024, 512)
        cfg.TEST.OUTPUT_SIZE_TARGET = (2048, 1024)
        cfg.TEST.INFO_TARGET = str(project_root / 'advent/dataset/cityscapes_list/info.json')
    elif cfg.TARGET == 'BDD':
        cfg.DATA_LIST_TARGET = str(project_root / 'advent/dataset/compound_list/{}.txt')
        cfg.DATA_DIRECTORY_TARGET = str(project_root / 'data/bdd/Compound')
        cfg.EXP_ROOT = project_root / 'experiments'
        cfg.EXP_ROOT_SNAPSHOT = osp.join(cfg.EXP_ROOT, 'snapshots')
        cfg.EXP_ROOT_LOGS = osp.join(cfg.EXP_ROOT, 'logs')
        cfg.TRAIN.INPUT_SIZE_TARGET = (960, 540)
        cfg.TRAIN.INFO_TARGET = str(project_root / 'advent/dataset/compound_list/info.json')
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    # auto-generate exp name if not specified
    if cfg.EXP_NAME == '':
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.TRAIN.OCDA_METHOD}'
    if args.exp_suffix:
        cfg.EXP_NAME += f'_{args.exp_suffix}'

    # auto-generate snapshot path if not specified
    if cfg.TRAIN.SNAPSHOT_DIR == '':
        cfg.TRAIN.SNAPSHOT_DIR = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR, exist_ok=True)

    # tensorboard
    if args.tensorboard:
        if cfg.TRAIN.TENSORBOARD_LOGDIR == '':
            cfg.TRAIN.TENSORBOARD_LOGDIR = osp.join(cfg.EXP_ROOT_LOGS, 'tensorboard', cfg.EXP_NAME)
        os.makedirs(cfg.TRAIN.TENSORBOARD_LOGDIR, exist_ok=True)
        if args.viz_every_iter is not None:
            cfg.TRAIN.TENSORBOARD_VIZRATE = args.viz_every_iter
    else:
        cfg.TRAIN.TENSORBOARD_LOGDIR = ''

    print('Using config:')
    pprint.pprint(cfg)

    # INIT
    _init_fn = None
    if not args.random_train:
        torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
        torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
        np.random.seed(cfg.TRAIN.RANDOM_SEED)
        random.seed(cfg.TRAIN.RANDOM_SEED)

        def _init_fn(worker_id):
            np.random.seed(cfg.TRAIN.RANDOM_SEED + worker_id)

    if os.environ.get('ADVENT_DRY_RUN', '0') == '1':
        return

    # LOAD SEGMENTATION NET
    trg_list = cfg.DATA_LIST_TARGET  # default target list; self-training overrides it below
    if cfg.TRAIN.MODEL == 'DeepLabv2':
        model = get_deeplab_v2(num_classes=cfg.NUM_CLASSES, multi_level=cfg.TRAIN.MULTI_LEVEL)
        saved_state_dict = torch.load(cfg.TRAIN.RESTORE_FROM)
        if 'DeepLab_resnet_pretrained_imagenet' in cfg.TRAIN.RESTORE_FROM:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)
        else:
            model.load_state_dict(saved_state_dict)
    elif cfg.TRAIN.MODEL == 'DeepLabv2_VGG':
        model = get_deeplab_v2_vgg(cfg=cfg, num_classes=cfg.NUM_CLASSES,
                                   pretrained_model=cfg.TRAIN_VGG_PRE_MODEL)
        if cfg.TRAIN.SELF_TRAINING:
            path = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.TRAIN.RESTORE_FROM_SELF)
            saved_state_dict = torch.load(path)
            model.load_state_dict(saved_state_dict, strict=False)
            trg_list = cfg.DATA_LIST_TARGET_ORDER
            print("self-training model loaded: {} ".format(path))
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TRAIN.MODEL}")
    print("model: ")
    print(model)
    print('Model loaded')

    ######## DATALOADERS ########
    # GTA5 source: 24,966 images; 274,626 / 24,966 = 11 epochs
    # self-training: target data shuffle
    shuffle = cfg.TRAIN.SHUFFLE
    if cfg.TRAIN.SELF_TRAINING:
        max_iteration = None
    else:
        max_iteration = cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_SOURCE

    source_dataset = GTA5DataSet(root=cfg.DATA_DIRECTORY_SOURCE,
                                 list_path=cfg.DATA_LIST_SOURCE,
                                 set=cfg.TRAIN.SET_SOURCE,
                                 max_iters=max_iteration,
                                 crop_size=cfg.TRAIN.INPUT_SIZE_SOURCE,
                                 mean=cfg.TRAIN.IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_SOURCE,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=_init_fn)

    if cfg.TARGET == "BDD":
        # BDD target: 14,697 images; 264,546 / 14,697 = 18 epochs
        target_dataset = BDDdataset(root=cfg.DATA_DIRECTORY_TARGET,
                                    list_path=trg_list,
                                    set=cfg.TRAIN.SET_TARGET,
                                    info_path=cfg.TRAIN.INFO_TARGET,
                                    max_iters=max_iteration,
                                    crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                    mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=shuffle,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    elif cfg.TARGET == 'Cityscapes':
        target_dataset = CityscapesDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                                           list_path=cfg.DATA_LIST_TARGET,
                                           set=cfg.TRAIN.SET_TARGET,
                                           info_path=cfg.TRAIN.INFO_TARGET,
                                           max_iters=cfg.TRAIN.MAX_ITERS * cfg.TRAIN.BATCH_SIZE_TARGET,
                                           crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                           mean=cfg.TRAIN.IMG_MEAN)
        target_loader = data.DataLoader(target_dataset,
                                        batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                        num_workers=cfg.NUM_WORKERS,
                                        shuffle=True,
                                        pin_memory=True,
                                        worker_init_fn=_init_fn)
    else:
        raise NotImplementedError(f"Not yet supported {cfg.TARGET} dataset")

    with open(osp.join(cfg.TRAIN.SNAPSHOT_DIR, 'train_cfg.yml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    # UDA TRAINING
    train_domain_adaptation(model, source_loader, target_loader, cfg)
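# Worked check of the epoch arithmetic quoted in the DATALOADERS comments above,
# using only the dataset sizes those comments themselves state: max_iteration
# caps the total samples drawn, so the epoch count is drawn / dataset size.
#   source: 274,626 drawn / 24,966 GTA5 images  = 11 full passes (24,966 * 11 = 274,626)
#   target: 264,546 drawn / 14,697 BDD images   = 18 full passes (14,697 * 18 = 264,546)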