log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) # Declare an image generator generator = OcclusionAwareGenerator( **config['model_params']['generator_params'], **config['model_params']['common_params']) # If GPU Available, adapt to it if torch.cuda.is_available(): generator.to(opt.device_ids[0]) if opt.verbose: print(generator) # Declare a discriminator discriminator = MultiScaleDiscriminator( **config['model_params']['discriminator_params'], **config['model_params']['common_params']) if torch.cuda.is_available(): discriminator.to(opt.device_ids[0]) if opt.verbose: print(discriminator) # Declare a key point detector kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']) if torch.cuda.is_available(): kp_detector.to(opt.device_ids[0]) # Print network details if using --verbose flag if opt.verbose:
# --- Command-line options (parser is created earlier in the file) ---
parser.add_argument("--save_dir", default='/home/aistudio/train_ckpt',
                    help="path to save in")
parser.add_argument("--preload", action='store_true',
                    help="preload dataset to RAM")
parser.set_defaults(verbose=False)
opt = parser.parse_args()

# Load the YAML experiment configuration.
# FIX: yaml.load() without an explicit Loader is deprecated since
# PyYAML 5.1 and raises TypeError in PyYAML >= 6. FullLoader keeps
# full-YAML parsing while refusing to construct arbitrary Python
# objects from the file.
with open(opt.config) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Build the three networks from the config's model_params sections.
generator = OcclusionAwareGenerator(
    **config['model_params']['generator_params'],
    **config['model_params']['common_params'])
discriminator = MultiScaleDiscriminator(
    **config['model_params']['discriminator_params'],
    **config['model_params']['common_params'])
kp_detector = KPDetector(
    **config['model_params']['kp_detector_params'],
    **config['model_params']['common_params'])

# Dataset: the train/test split is selected by --mode.
dataset = FramesDataset(is_train=(opt.mode == 'train'),
                        **config['dataset_params'])

# Optionally pre-decode every sample into RAM with a small worker pool
# so the training loop is not bottlenecked on disk I/O.
if opt.preload:
    logging.info('PreLoad Dataset: Start')
    pre_list = list(range(len(dataset)))
    import multiprocessing.pool as pool
    with pool.Pool(4) as pl:
        buf = pl.map(dataset.preload, pre_list)
    for idx, (i, v) in enumerate(zip(pre_list, buf)):
        dataset.buffed[i] = v.copy()
        # Drop the pool's copy immediately so peak memory stays near one
        # dataset-sized buffer instead of two.
        buf[idx] = None