def get_dataset(args):
    """Build the dataset selected by ``args.data_type``.

    Supported values of ``args.data_type``:

    * ``'image'`` -- a ``utils.ImageFolder`` over ``args.data_dir``. The
      requested height and width are the two entries of ``args.base_shape``,
      each multiplied by ``2 ** (n - 1)`` where ``n`` is the length of
      ``args.d_channels`` (or of ``args.channels`` when ``d_channels`` is
      falsy).
    * ``'lmdb'`` -- a ``utils.LmdbDataset`` over ``args.data_dir`` using
      ``args.resolution``.

    Args:
        args: Options namespace. Reads ``data_dir``, ``data_type``,
            ``mirror_augment``, ``pixel_min``, ``pixel_max`` and, depending on
            the data type, ``base_shape``, ``d_channels``, ``channels``,
            ``data_resize``, ``data_channels`` or ``resolution``.

    Returns:
        The constructed dataset (guaranteed to have a non-zero length).

    Raises:
        AssertionError: If ``args.data_dir`` is falsy or the dataset is empty.
        ValueError: If ``args.data_type`` is not a supported value.
    """
    assert args.data_dir, '--data_dir has to be specified.'

    if args.data_type == 'image':
        # One scale factor per base dimension; ``base_shape`` must hold
        # exactly two entries for the unpacking below to succeed.
        height, width = [
            shape * 2 ** (len(args.d_channels or args.channels) - 1)
            for shape in args.base_shape
        ]
        dataset = utils.ImageFolder(
            args.data_dir,
            mirror=args.mirror_augment,
            pixel_min=args.pixel_min,
            pixel_max=args.pixel_max,
            height=height,
            width=width,
            resize=args.data_resize,
            grayscale=args.data_channels == 1,
        )
    elif args.data_type == 'lmdb':
        dataset = utils.LmdbDataset(
            args.data_dir,
            mirror=args.mirror_augment,
            pixel_min=args.pixel_min,
            pixel_max=args.pixel_max,
            resolution=args.resolution,
        )
    else:
        # Previously this fell through with ``dataset = None`` and crashed in
        # ``len(None)`` with an unhelpful TypeError; report the real problem.
        raise ValueError(
            "Unsupported data_type {!r}; expected 'image' or 'lmdb'.".format(
                args.data_type))

    assert len(dataset), 'No images found at {}'.format(args.data_dir)
    return dataset
def eval_fid(G, prior_generator, args):
    """Compute the FID metric for ``G`` and report it.

    Real images are read from ``args.data_dir`` through ``utils.ImageFolder``
    and features come from an ``InceptionV3FeatureExtractor``, which is
    wrapped in ``torch.nn.DataParallel`` when more than one GPU is listed in
    ``args.gpu``. The resulting value is passed to ``_report_metric`` under a
    name assembled from the evaluation settings.

    Args:
        G: Generator handed to the FID metric.
        prior_generator: Prior generator handed to the FID metric.
        args: Options namespace. Reads ``data_dir``, ``pixel_min``,
            ``pixel_max``, ``gpu``, ``reals_batch_size``, ``num_samples``,
            ``size``, ``truncation_psi`` and ``reals_data_workers``.
            ``args.reals_batch_size`` is overwritten in place (see below).

    Raises:
        AssertionError: If ``args.data_dir`` is falsy or no images are found.
    """
    assert args.data_dir, '--data_dir has to be specified.'
    reals = utils.ImageFolder(
        args.data_dir,
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max,
    )
    assert len(reals), 'No images found at {}'.format(args.data_dir)

    extractor = stylegan2.external_models.inception.InceptionV3FeatureExtractor(
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max,
    )
    num_gpus = len(args.gpu)
    if num_gpus > 1:
        extractor = torch.nn.DataParallel(extractor, device_ids=args.gpu)
    # The real-image batch size is raised to at least the number of listed
    # devices; note that this mutates ``args``.
    args.reals_batch_size = max(args.reals_batch_size, num_gpus)

    metric = stylegan2.metrics.fid.FID(
        G=G,
        prior_generator=prior_generator,
        dataset=reals,
        num_samples=args.num_samples,
        fid_model=extractor,
        fid_size=args.size,
        truncation_psi=args.truncation_psi,
        reals_batch_size=args.reals_batch_size,
        reals_data_workers=args.reals_data_workers,
    )
    score = metric.evaluate()

    # Metric label: 'FID', optional '(size)', optional 'trunc<psi>', then the
    # sample count in thousands.
    label_parts = ['FID']
    if args.size:
        label_parts.append('({})'.format(args.size))
    if args.truncation_psi != 1:
        label_parts.append('trunc{}'.format(args.truncation_psi))
    label_parts.append(':{}k'.format(args.num_samples // 1000))
    _report_metric(score, ''.join(label_parts), args)
def get_dataset(args):
    """Create the ``utils.ImageFolder`` dataset described by ``args``.

    The requested height and width are the two entries of ``args.base_shape``,
    each multiplied by ``2 ** (n - 1)`` where ``n`` is the length of
    ``args.d_channels`` (or of ``args.channels`` when ``d_channels`` is
    falsy).

    Args:
        args: Options namespace. Reads ``data_dir``, ``d_channels``,
            ``channels``, ``base_shape``, ``mirror_augment``, ``pixel_min``,
            ``pixel_max``, ``data_resize`` and ``data_channels``.

    Returns:
        The constructed, non-empty image dataset.

    Raises:
        AssertionError: If ``args.data_dir`` is falsy or no images are found.
    """
    assert args.data_dir, '--data_dir has to be specified.'

    # Scale every base dimension by 2 ** (n - 1); e.g. a base of 4 gives
    # 4 * 2 ** (n - 1).
    scaled_dims = []
    for base in args.base_shape:
        depth = len(args.d_channels or args.channels)
        scaled_dims.append(base * 2 ** (depth - 1))
    height, width = scaled_dims

    # NOTE(review): earlier inline notes suggested pixel_min=-1, pixel_max=1
    # and data_resize=True as the usual values -- confirm against the CLI
    # defaults.
    folder_options = dict(
        mirror=args.mirror_augment,
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max,
        height=height,
        width=width,
        resize=args.data_resize,
        grayscale=args.data_channels == 1,
    )
    image_folder = utils.ImageFolder(args.data_dir, **folder_options)
    assert len(image_folder), 'No images found at {}'.format(args.data_dir)
    return image_folder
def project_real_images(G, args):
    """Project a random sample of real images from ``args.data_dir``.

    Up to ``args.num_images`` distinct images are drawn (without replacement)
    from a ``utils.ImageFolder`` using a ``RandomState`` seeded with
    ``args.seed``, stacked into one tensor on the selected device, and handed
    to ``project_images`` together with ``'imageNNNN-'`` name prefixes built
    from the chosen dataset indices.

    Args:
        G: Generator forwarded to ``project_images``.
        args: Options namespace. Reads ``data_dir``, ``gpu``, ``pixel_min``,
            ``pixel_max``, ``seed`` and ``num_images``; it is also forwarded
            to ``project_images``.

    Raises:
        AssertionError: If ``args.data_dir`` is falsy or no images are found.
            These checks use the same messages as the other dataset-loading
            functions in this file; without them an empty folder only failed
            later inside ``torch.stack`` with an unrelated error.
    """
    assert args.data_dir, '--data_dir has to be specified.'

    # First listed GPU if any, otherwise CPU.
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')

    print('Loading images from "%s"...' % args.data_dir)
    dataset = utils.ImageFolder(
        args.data_dir,
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max,
    )
    assert len(dataset), 'No images found at {}'.format(args.data_dir)

    rnd = np.random.RandomState(args.seed)
    indices = rnd.choice(
        len(dataset),
        size=min(args.num_images, len(dataset)),
        replace=False,
    )

    images = []
    for i in indices:
        data = dataset[i]
        # Items may be a tuple/list; only the first element (the image) is
        # kept in that case.
        if isinstance(data, (tuple, list)):
            data = data[0]
        images.append(data)
    images = torch.stack(images).to(device)

    name_prefix = ['image%04d-' % i for i in indices]
    print('Done!')
    project_images(G, images, name_prefix, args)