def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True):
    """Evaluate this metric for one network pickle and optionally log the result.

    Args:
        network_pkl: Passed to ``misc.load_pkl``; must unpack into ``(G, D, Gs)``.
        run_dir: Previous run directory. Used to recover dataset/mirror settings
            when they are not given, and as the location of the metric log file.
        dataset_args: Dataset options stored on the instance for ``_evaluate``.
        mirror_augment: Mirror flag stored on the instance for ``_evaluate``.
        num_gpus: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the result string (through a
            ``dnnlib.util.Logger`` appending to ``metric-<name>.txt`` if
            ``run_dir`` is given).
    """
    self._network_pkl = network_pkl
    self._dataset_args = dataset_args
    self._mirror_augment = mirror_augment
    self._results = []

    # If either option was left unspecified, take both from the config of the previous run.
    options_incomplete = dataset_args is None or mirror_augment is None
    if options_incomplete and run_dir is not None:
        previous = misc.parse_config_for_previous_run(run_dir)
        restored = dict(previous['dataset'])
        restored['shuffle_mb'] = 0
        self._dataset_args = restored
        self._mirror_augment = previous['train'].get('mirror_augment', False)

    started = time.time()
    with tf.Graph().as_default():
        with tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
    self._eval_time = time.time() - started

    if not log_results:
        return

    summary = self.get_result_str()
    if run_dir is None:
        print(summary)
        return

    log_path = os.path.join(run_dir, 'metric-%s.txt' % self.name)
    with dnnlib.util.Logger(log_path, 'a'):
        print(summary)
def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True, Gs_kwargs=dict(is_validation=True)):
    """Evaluate this metric for one network pickle and optionally log the result.

    The pickle must unpack into four objects; only the last one (``Gs``) is
    handed to ``_evaluate``.

    Args:
        network_pkl, run_dir, data_dir, dataset_args, mirror_augment:
            Forwarded to ``self._reset``.
        num_gpus: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the stripped result string (through a
            ``dnnlib.util.Logger`` appending to ``metric-<name>.txt`` if
            ``run_dir`` is given).
        Gs_kwargs: Forwarded to ``_evaluate`` unchanged.
    """
    self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir,
                dataset_args=dataset_args, mirror_augment=mirror_augment)

    started = time.time()
    with tf.Graph().as_default():
        with tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            # Four-element snapshot layout (the three-element layout is not accepted here).
            loaded = misc.load_pkl(self._network_pkl)
            _G, _D, _I, Gs = loaded
            self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
            self._report_progress(1, 1)
    self._eval_time = time.time() - started  # pylint: disable=attribute-defined-outside-init

    if not log_results:
        return

    if run_dir is None:
        print(self.get_result_str().strip())
        return

    log_path = os.path.join(run_dir, 'metric-%s.txt' % self.name)
    with dnnlib.util.Logger(log_path, 'a'):
        print(self.get_result_str().strip())
def run(self, network_pkl, num_gpus=1, G_kwargs=dict(is_validation=True)):
    """Evaluate this metric for one network pickle and report the result.

    The result string is written with ``tqdm.write`` and, when ``self._run_dir``
    is an existing directory, appended to ``metric-<name>.txt`` inside it.

    Args:
        network_pkl: Path or URL opened with ``dnnlib.util.open_url``; the
            pickle must unpack into ``(G, D, Gs)``.
        num_gpus: Forwarded to ``_evaluate``.
        G_kwargs: Generator keyword arguments. They are copied into a fresh
            ``dnnlib.EasyDict`` and overlaid with ``self.force_G_kwargs``, so
            neither the caller's dict nor the shared default is modified.
    """
    self._results = []
    self._network_name = os.path.splitext(os.path.basename(network_pkl))[0]
    self._eval_time = 0
    self._dataset = None

    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        self._report_progress(0, 1)
        time_begin = time.time()
        # NOTE: unpickling executes arbitrary code; only load network pickles from trusted sources.
        with dnnlib.util.open_url(network_pkl) as f:
            G, D, Gs = pickle.load(f)

        G_kwargs = dnnlib.EasyDict(G_kwargs)
        G_kwargs.update(self.force_G_kwargs)
        try:
            self._evaluate(G=G, D=D, Gs=Gs, G_kwargs=G_kwargs, num_gpus=num_gpus)
            self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init
            self._report_progress(1, 1)
        finally:
            # Release the dataset even when evaluation fails part-way; previously it
            # was only closed on the success path. (Presumably opened lazily during
            # _evaluate; it is None at this point otherwise.)
            if self._dataset is not None:
                self._dataset.close()
                self._dataset = None

    result_str = self.get_result_str()
    tqdm.write(result_str)
    if self._run_dir is not None and os.path.isdir(self._run_dir):
        with open(os.path.join(self._run_dir, f'metric-{self.name}.txt'), 'at') as f:
            f.write(result_str + '\n')
def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True, model_type="rignet"):
    """Evaluate this metric for an encoder/generator pickle plus the inversion network.

    Args:
        network_pkl: Passed to ``misc.load_pkl``; must unpack into ``(E, G, D, Gs)``.
        run_dir: Previous run directory. Used to recover dataset/mirror settings
            when they are not given, and as the location of the metric log file.
        dataset_args: Dataset options stored on the instance for ``_evaluate``.
        mirror_augment: Mirror flag stored on the instance for ``_evaluate``.
        num_gpus: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the result string and also append it to
            ``result_<pickle name>.txt`` under ``config.EVALUATION_DIR``.
        model_type: Stored on the instance as ``self.model_type``.
    """
    create_dir(config.EVALUATION_DIR, exist_ok=True)

    self._network_pkl = network_pkl
    self._dataset_args = dataset_args
    self._mirror_augment = mirror_augment
    self._results = []
    self.model_type = model_type

    # If either option was left unspecified, take both from the config of the previous run.
    options_incomplete = dataset_args is None or mirror_augment is None
    if options_incomplete and run_dir is not None:
        previous = misc.parse_config_for_previous_run(run_dir)
        restored = dict(previous['dataset'])
        restored['shuffle_mb'] = 0
        self._dataset_args = restored
        self._mirror_augment = previous['train'].get('mirror_augment', False)

    started = time.time()
    with tf.Graph().as_default():
        with tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
            E, _G, _D, Gs = misc.load_pkl(self._network_pkl)
            print("Loaded Encoder")
            # Only the first object of the inversion pickle is used.
            Inv, _, _, _ = misc.load_pkl(config.INVERSION_PICKLE_DIR)
            print("Loaded Inv")
            self._evaluate(Gs, E, Inv, num_gpus=num_gpus)
    self._eval_time = time.time() - started

    if not log_results:
        return

    summary = self.get_result_str()
    if run_dir is None:
        print(summary)
    else:
        log_path = os.path.join(run_dir, 'metric-%s.txt' % self.name)
        with dnnlib.util.Logger(log_path, 'a'):
            print(summary)

    result_name = "result_" + convert_pickle_path_to_name(self._network_pkl) + ".txt"
    write_to_file(summary + "\n\n\n", os.path.join(config.EVALUATION_DIR, result_name))
def run(self, network_pkl, num_imgs, run_dir=None, data_dir=None, dataset_args=None,
        mirror_augment=None, ratio=1.0, num_gpus=1, tf_config=None, log_results=True,
        Gs_kwargs=dict(is_validation=True), eval_mod=False, **kwargs):
    """Evaluate this metric and return the value of its first reported result.

    Args:
        network_pkl, run_dir, data_dir, dataset_args, mirror_augment:
            Forwarded to ``self._reset``.
        num_imgs, ratio, num_gpus, Gs_kwargs, **kwargs: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the result string to the screen and, if
            ``run_dir`` is given, also append it to ``metric-<name>.txt``.
        eval_mod: Stored on the instance as ``self.eval_mod``.

    Returns:
        ``self._results[0].value``.
    """
    self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir,
                dataset_args=dataset_args, mirror_augment=mirror_augment)
    self.eval_mod = eval_mod

    started = time.time()
    with tf.Graph().as_default():
        with tflib.create_session(tf_config).as_default():
            self._report_progress(0, 1)
            # Without a pickle, _evaluate receives Gs=None.
            Gs = None
            if self._network_pkl is not None:
                _, _, Gs = misc.load_pkl(self._network_pkl)[:3]
            self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus,
                           num_imgs=num_imgs, ratio=ratio, **kwargs)
            self._report_progress(1, 1)
    self._eval_time = time.time() - started

    if log_results:
        if run_dir is not None:
            log_path = os.path.join(run_dir, "metric-%s.txt" % self.name)
            # screen=False is passed to the Logger; the on-screen copy is printed below.
            with dnnlib.util.Logger(log_path, "a", screen=False):
                print(self.get_result_str().strip())
        print(self.get_result_str(screen=True).strip())

    first_result = self._results[0]
    return first_result.value
def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True, include_I=False, avg_mv_for_I=False,
        Gs_kwargs=dict(is_validation=True, return_atts=False), train_infernet=False,
        is_vae=False, use_D=False, **kwargs):
    """Evaluate this metric for one of several snapshot layouts and return its output.

    The flags select how the pickle is unpacked (first match wins):

    * ``include_I`` and ``avg_mv_for_I``: ``(G, D, I, Gs, I_avg)``; the last is used as ``I_net``.
    * ``include_I`` only: ``(G, D, I, Gs)``.
    * ``train_infernet``: ``(I, Gs)``.
    * ``is_vae`` and ``use_D``: ``(I, Gs, D)``.
    * ``is_vae`` only: ``(I, Gs)``.
    * none of the above: ``(G, D, Gs)``, and ``_evaluate`` gets no ``I_net``.

    Args:
        network_pkl, run_dir, data_dir, dataset_args, mirror_augment:
            Forwarded to ``self._reset``.
        num_gpus, Gs_kwargs, **kwargs: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the stripped result string (through a
            ``dnnlib.util.Logger`` appending to ``metric-<name>.txt`` if
            ``run_dir`` is given).

    Returns:
        Whatever ``_evaluate`` returns.
    """
    self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir,
                dataset_args=dataset_args, mirror_augment=mirror_augment)

    started = time.time()
    with tf.Graph().as_default():
        with tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            snapshot = misc.load_pkl(self._network_pkl)

            inference_kwargs = {}
            if include_I:
                if avg_mv_for_I:
                    _G, _D, _I, Gs, infer_net = snapshot
                else:
                    _G, _D, infer_net, Gs = snapshot
                inference_kwargs['I_net'] = infer_net
            elif train_infernet:
                infer_net, Gs = snapshot
                inference_kwargs['I_net'] = infer_net
            elif is_vae:
                if use_D:
                    infer_net, Gs, _discriminator = snapshot
                else:
                    infer_net, Gs = snapshot
                inference_kwargs['I_net'] = infer_net
            else:
                _G, _D, Gs = snapshot

            outs = self._evaluate(Gs=Gs, Gs_kwargs=Gs_kwargs, **inference_kwargs,
                                  num_gpus=num_gpus, **kwargs)
            self._report_progress(1, 1)
    self._eval_time = time.time() - started  # pylint: disable=attribute-defined-outside-init

    if log_results:
        if run_dir is None:
            print(self.get_result_str().strip())
        else:
            log_path = os.path.join(run_dir, 'metric-%s.txt' % self.name)
            with dnnlib.util.Logger(log_path, 'a'):
                print(self.get_result_str().strip())
    return outs
def setup_training_options(
    # General options (not included in desc).
    gpus=None,  # Number of GPUs: <int>, default = 1 gpu
    snap=None,  # Snapshot interval: <int>, default = 50 ticks

    # Training dataset.
    data=None,  # Training dataset (required): <path>
    res=None,  # Override dataset resolution: <int>, default = highest available
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False
    mirrory=None,  # Augment dataset with y-flips: <bool>, default = False
    use_raw=None,  # Forwarded unchanged to train_dataset_args.use_raw (not validated here, not included in desc)

    # Metrics (not included in desc).
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    metricdata=None,  # Metric dataset (optional): <path>

    # Base config.
    cfg=None,  # Base config: 'auto' (default), '11gb-gpu', '11gb-gpu-complex', '24gb-gpu', '24gb-gpu-complex', '48gb-gpu', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
    gamma=None,  # Override R1 gamma: <float>, default = depends on cfg
    kimg=None,  # Override training duration: <int>, default = depends on cfg

    # Discriminator augmentation.
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
    p=None,  # Specify p for 'fixed' (required): <float>
    target=None,  # Override ADA target for 'ada' and 'adarv': <float>, default = depends on aug
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

    # Comparison methods.
    cmethod=None,  # Comparison method: 'nocmethod' (default), 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'
    dcap=None,  # Multiplier for discriminator capacity: <float>, default = 1

    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), any key of resume_specs below ('ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', ...), <file>, <url>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
):
    """Validate the training options and translate them into training arguments.

    Every option defaults to None, which selects the default documented next to
    the parameter. The training dataset is opened once to read its resolution,
    label size and whether it has a validation set.

    Returns:
        (desc, args): ``desc`` is a string that starts with the dataset
        directory name and gets a '-'-separated tag appended for each option
        that contributes to it; ``args`` is a ``dnnlib.EasyDict`` holding the
        network, optimizer, loss, augmentation, dataset, metric and schedule
        settings.

    Raises:
        UserError: An option value is out of range, unsupported, or
            incompatible with another option.
        AssertionError: An option has the wrong Python type, ``data`` is None,
            or ``cfg`` / ``augpipe`` is not a known key.
    """
    # Initialize dicts.
    args = dnnlib.EasyDict()
    args.G_args = dnnlib.EasyDict(func_name='training.networks.G_main')
    args.D_args = dnnlib.EasyDict(func_name='training.networks.D_main')
    args.G_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.D_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.loss_args = dnnlib.EasyDict(func_name='training.loss.stylegan2')
    args.augment_args = dnnlib.EasyDict(
        class_name='training.augment.AdaptiveAugment')

    # ---------------------------
    # General options: gpus, snap
    # ---------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    # x & (x - 1) == 0 holds exactly for powers of two (given x >= 1).
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    # The same interval is used for image and network snapshots.
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    # ---------------------------------------------
    # Training dataset: data, res, mirror, mirrory
    # ---------------------------------------------

    assert data is not None
    assert isinstance(data, str)
    data_name = os.path.basename(os.path.abspath(data))
    if not os.path.isdir(data) or len(data_name) == 0:
        raise UserError(
            '--data must point to a directory containing *.tfrecords')
    desc = data_name

    # Probe the dataset inside a throwaway graph/session.
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        args.train_dataset_args = dnnlib.EasyDict(path=data, max_label_size='full')
        args.train_dataset_args.use_raw = use_raw
        dataset_obj = dataset.load_dataset(
            **args.train_dataset_args
        )  # try to load the data and see what comes out
        args.train_dataset_args.resolution = dataset_obj.shape[
            -1]  # be explicit about resolution
        args.train_dataset_args.max_label_size = dataset_obj.label_size  # be explicit about label size
        validation_set_available = dataset_obj.has_validation_set
        dataset_obj.close()
        dataset_obj = None

    if res is None:
        res = args.train_dataset_args.resolution
    else:
        assert isinstance(res, int)
        if not (res >= 4 and res & (res - 1) == 0):
            raise UserError('--res must be a power of two and at least 4')
        if res > args.train_dataset_args.resolution:
            raise UserError(
                f'--res cannot exceed maximum available resolution in the dataset ({args.train_dataset_args.resolution})'
            )
        desc += f'-res{res:d}'
    args.train_dataset_args.resolution = res

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
    args.train_dataset_args.mirror_augment = mirror

    if mirrory is None:
        mirrory = False
    assert isinstance(mirrory, bool)
    if mirrory:
        desc += '-mirrory'
    args.train_dataset_args.mirror_augment_v = mirrory
    # NOTE(review): use_raw was already set when the dataset args were created above; this repeats it.
    args.train_dataset_args.use_raw = use_raw

    # ----------------------------
    # Metrics: metrics, metricdata
    # ----------------------------

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    assert all(isinstance(metric, str) for metric in metrics)

    args.metric_arg_list = []
    for metric in metrics:
        if metric not in metric_defaults.metric_defaults:
            # NOTE(review): 'none' is listed in the message but is not a key handled here;
            # presumably the caller maps it to an empty list. Confirm.
            raise UserError('\n'.join(
                ['--metrics can only contain the following values:', 'none'] +
                list(metric_defaults.metric_defaults.keys())))
        args.metric_arg_list.append(metric_defaults.metric_defaults[metric])

    # Metrics use a copy of the training dataset args unless a separate path is given.
    args.metric_dataset_args = dnnlib.EasyDict(args.train_dataset_args)
    if metricdata is not None:
        assert isinstance(metricdata, str)
        if not os.path.isdir(metricdata):
            raise UserError(
                '--metricdata must point to a directory containing *.tfrecords'
            )
        args.metric_dataset_args.path = metricdata

    # -----------------------------
    # Base config: cfg, gamma, kimg
    # -----------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=8),  # populated dynamically based on 'gpus' and 'res'
        '11gb-gpu': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, 11GB GPU
        '11gb-gpu-complex': dict(ref_gpus=1, kimg=25000, mb=4, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, 11GB GPU
        '24gb-gpu': dict(ref_gpus=1, kimg=25000, mb=8, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, 24GB GPU
        '24gb-gpu-complex': dict(ref_gpus=1, kimg=25000, mb=8, mbstd=8, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, 24GB GPU
        '48gb-gpu': dict(ref_gpus=1, kimg=25000, mb=16, mbstd=16, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, 48GB GPU
        'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # uses mixed-precision, unlike original StyleGAN2
        'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
        'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
        'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
        'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
        'cifarbaseline': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        # The GPU count is appended directly after 'auto' (no separator), e.g. '-auto2'.
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.total_kimg = spec.kimg
    args.minibatch_size = spec.mb
    # Per-GPU batch is derived from the config's reference GPU count, not from --gpus.
    args.minibatch_gpu = spec.mb // spec.ref_gpus
    args.D_args.mbstd_group_size = spec.mbstd
    args.G_args.fmap_base = args.D_args.fmap_base = int(spec.fmaps * 16384)
    args.G_args.fmap_max = args.D_args.fmap_max = 512
    args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = spec.lrate
    args.loss_args.r1_gamma = spec.gamma
    args.G_smoothing_kimg = spec.ema
    args.G_smoothing_rampup = spec.ramp
    args.G_args.mapping_layers = spec.map
    args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4  # enable mixed-precision training
    args.G_args.conv_clamp = args.D_args.conv_clamp = 256  # clamp activations to avoid float16 overflow

    # Applies to 'cifar' and to every config whose name ends in '-complex'.
    if cfg == 'cifar' or cfg.split('-')[-1] == 'complex':
        args.loss_args.pl_weight = 0  # disable path length regularization
        args.G_args.style_mixing_prob = None  # disable style mixing
        args.D_args.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_args.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.augment_args.tune_heuristic = 'rt'
        args.augment_args.tune_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    elif aug == 'adarv':
        if not validation_set_available:
            raise UserError(
                f'--aug={aug} requires separate validation set; please see "python dataset_tool.py pack -h"'
            )
        args.augment_args.tune_heuristic = 'rv'
        args.augment_args.tune_target = 0.5
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_args.initial_strength = p

    if target is not None:
        assert isinstance(target, float)
        if aug not in ['ada', 'adarv']:
            raise UserError(
                '--target can only be specified with --aug=ada or --aug=adarv')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.augment_args.tune_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit': dict(xflip=1, rotate90=1, xint=1),
        'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise': dict(noise=1),
        'cutout': dict(cutout=1),
        'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_args.apply_func = 'training.augment.augment_pipeline'
        args.augment_args.apply_args = augpipe_specs[augpipe]

    # ---------------------------------
    # Comparison methods: cmethod, dcap
    # ---------------------------------

    assert cmethod is None or isinstance(cmethod, str)
    if cmethod is None:
        cmethod = 'nocmethod'
    else:
        desc += f'-{cmethod}'

    if cmethod == 'nocmethod':
        pass
    elif cmethod == 'bcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.bcr_real_weight = 10
        args.loss_args.bcr_fake_weight = 10
        args.loss_args.bcr_augment = dnnlib.EasyDict(
            func_name='training.augment.augment_pipeline',
            xint=1,
            xint_max=1 / 32)
    elif cmethod == 'zcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.zcr_gen_weight = 0.02
        args.loss_args.zcr_dis_weight = 0.2
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
    elif cmethod == 'pagan':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.use_pagan = True
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.95
    elif cmethod == 'wgangp':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        if gamma is not None:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with --gamma')
        # Replaces the whole loss config, dropping any r1_gamma / pl_weight set above.
        args.loss_args = dnnlib.EasyDict(func_name='training.loss.wgangp')
        args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = 0.001
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
        args.lazy_regularization = False
    elif cmethod == 'auxrot':
        if args.train_dataset_args.max_label_size > 0:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with label conditioning; please specify a dataset without labels'
            )
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.auxrot_alpha = 10
        args.loss_args.auxrot_beta = 5
        args.D_args.score_max = 5  # prepare D to output 5 scalars per image instead of just 1
    elif cmethod == 'spectralnorm':
        args.D_args.use_spectral_norm = True
    elif cmethod == 'shallowmap':
        if args.G_args.mapping_layers == 2:
            raise UserError(f'--cmethod={cmethod} is a no-op for --cfg={cfg}')
        args.G_args.mapping_layers = 2
    elif cmethod == 'adropout':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.adaptive_dropout = 1
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.6
    else:
        raise UserError(f'--cmethod={cmethod} not supported')

    if dcap is not None:
        assert isinstance(dcap, float)
        if not dcap > 0:
            raise UserError('--dcap must be positive')
        desc += f'-dcap{dcap:g}'
        # Scale the discriminator's feature-map counts, keeping each at least 1.
        args.D_args.fmap_base = max(int(args.D_args.fmap_base * dcap), 1)
        args.D_args.fmap_max = max(int(args.D_args.fmap_max * dcap), 1)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
        'afhqcat512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl',
        'afhqdog512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl',
        'afhqwild512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqwild.pkl',
        'brecahad512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/brecahad.pkl',
        'cifar10': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl',
        'metfaces': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.augment_args.tune_kimg = 100  # make ADA react faster at the beginning
        args.G_smoothing_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_args.freeze_layers = freezed

    return desc, args
def load_network(self, network_path: str) -> None:
    """Load a generator snapshot and build the TensorFlow graph pieces used with it.

    Creates a default session, loads ``Gs`` from ``'../../results/' +
    network_path`` and the VGG16 perceptual network, then defines the
    placeholders and output tensors stored on the instance (image generation,
    synthesis from dlatents, mapping, interpolation gradients and the VGG16
    distance).

    Args:
        network_path: Snapshot path relative to ``../../results/``. The path
            text is also inspected: the first ``256`` or ``512`` found sets
            ``self.image_resolution`` (default 256), and a ``v<digits>`` token
            overrides ``self.label_version``.
    """
    self.session = tflib.create_session(None, force_as_default=True)

    # Load Networks
    # Default is an int, matching the int parsed from the path below
    # (the default used to be the string '256').
    self.image_resolution = 256
    replace_res = re.findall('256|512', network_path)
    if replace_res:
        self.image_resolution = int(replace_res[0])
    self.current_network_checkpoint = network_path
    _, _, self.Gs = misc.load_pkl('../../results/' + network_path)
    self._vgg16_model = misc.load_pkl(
        'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/vgg16_zhang_perceptual.pkl'
    )

    # Define Placeholders
    synthesis = self.Gs.components.synthesis
    self.latent_placeholder = tf.placeholder(tf.float32, shape=self.Gs.input_shapes[0])
    self.label_placeholder = tf.placeholder(tf.float32, shape=self.Gs.input_shapes[1])
    self.dlatent_placeholder = tf.placeholder(tf.float32, shape=synthesis.input_shape)
    self.dlatent_left_placeholder = tf.placeholder(tf.float32, synthesis.input_shape)
    self.dlatent_right_placeholder = tf.placeholder(tf.float32, synthesis.input_shape)
    self.magnitude_placeholder = tf.placeholder(tf.float32, shape=())
    self.delta_magnitude_placeholder = tf.placeholder(tf.float32, shape=())

    # Label version: inferred from the label width, then overridden by an explicit
    # 'v<N>' token in the path. NOTE(review): for any other width and no token in
    # the path, label_version is not assigned here.
    label_size = self.Gs.input_shapes[1][1]
    if label_size == 121:
        self.label_version = 'v7'
    if label_size == 127:
        self.label_version = 'v5'
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    read_label_version = re.findall(r'v\d+', network_path)
    if read_label_version:
        self.label_version = read_label_version[0]

    # Generate Image
    self._output_image = self.Gs.get_output_for(self.latent_placeholder,
                                                self.label_placeholder,
                                                randomize_noise=False,
                                                truncation_psi=None,
                                                style_mixing_prob=None)
    self._synthesis = synthesis.get_output_for(self.dlatent_placeholder,
                                               randomize_noise=False)

    if 'baseline_without_labels' in self.current_network_checkpoint:
        # Zero-width label tensor for the mapping network of the label-free baseline.
        zero_label = tf.zeros(shape=[1, 0], dtype=tf.float32)
        self._mapping_without_labels = self.Gs.components.mapping.get_output_for(
            self.latent_placeholder, zero_label)

    # Separate Mapping
    if 'mapping' in self.Gs.components:
        self.separate_mapping = False
        self._mapping = self.Gs.components.mapping.get_output_for(
            self.latent_placeholder, self.label_placeholder)
    else:
        self.separate_mapping = True
        self.concat = synthesis.input_shape[2] > 512
        self._latent_separate_mapping = self.Gs.components.mapping_latent.get_output_for(
            self.latent_placeholder)
        self._label_separate_mapping = self.Gs.components.mapping_label.get_output_for(
            self.label_placeholder)

    # Interpolation Gradient Graph
    dlatent_int_delta, dlatent_int = utils.delta_lerp(
        a=self.dlatent_left_placeholder,
        b=self.dlatent_right_placeholder,
        mag=self.magnitude_placeholder,
        delta_mag=self.delta_magnitude_placeholder,
    )
    image_int_delta = synthesis.get_output_for(dlatent_int_delta, randomize_noise=False)
    self._gradients = tf.gradients(image_int_delta, [dlatent_int])[0]

    # VGG16 Graph
    image_int = synthesis.get_output_for(dlatent_int, randomize_noise=False)
    self._vgg16_distance = self._vgg16_model.get_output_for(image_int, image_int_delta)
    print('Successfully loaded:', self.current_network_checkpoint)
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v,
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the training configuration and submit a local training run.

    Args:
        data: Dataset path; a matching ``*.tfr`` file next to it is reused, or
            one is created from the images.
        train_dir: Root directory for run results.
        config: ``'e'`` or ``'f'``. Anything else prints a message and exits.
        d_aug: Use the DiffAugment training loop and loss instead of the
            original ones.
        diffaug_policy: Policy passed to the DiffAugment loss.
        cond: Build/use a label-conditioned dataset.
        ops: Value assigned to ``G.impl`` and ``D.impl``.
        mirror, mirror_v: Mirror augmentation flags for the training loop.
        kimg: Training length in thousands of images.
        batch_size: Per-GPU batch size; ``None`` selects ``4096 // resolution``.
        lrate: Generator learning rate for config ``'f'``.
        resume, resume_kimg: Snapshot to resume from and its progress.
        num_gpus: Number of GPUs.
        ema_kimg: Optional ``G_ema_kimg`` override.
        gamma: Loss gamma for config ``'e'``; ``None`` selects 100.
        freezeD: Set ``D.freezeD``.

    Raises:
        SystemExit: If ``config`` is not ``'e'`` or ``'f'``.
    """
    import sys

    # Validate first. Previously this check ran after the dataset had been
    # created/opened and after 'sched.D_lrate_base = sched.G_lrate_base', which
    # fails for unknown configs before the intended message could be printed.
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        sys.exit()

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(run_func_name='training.training_loop_diffaug.training_loop')  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(func_name='training.loss_diffaug.ns_DiffAugment_r1', policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1')  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    tfr_files = file_list(os.path.dirname(data), 'tfr')
    tfr_files = [f for f in tfr_files if basename(data) == basename(f).split('-')[0]]
    if len(tfr_files) == 0 or os.stat(tfr_files[0]).st_size == 0:
        # Both creators return (file, sample count); only the file is needed here.
        tfr_file, _ = create_from_image_folders(data) if cond is True else create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(**dataset_args)  # loading the data to see what comes out
        try:
            resolution = dataset_obj.resolution
            init_res = dataset_obj.init_res
        finally:
            dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (basename(data), resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    train.savenames = [desc.replace(basename(data), 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    # Network capacity; gamma is only applied for config 'e'.
    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug:
            loss_args.gamma = 100 if gamma is None else gamma
        else:
            D_loss.gamma = 100 if gamma is None else gamma
    else:  # config == 'f' (validated at the top)
        G.fmap_base = D.fmap_base = 16 << 10

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label
    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, tf_config=tf_config)
    # resume_with_new_nets is always forced on here, regardless of freezeD.
    kwargs.update(resume_pkl=resume, resume_kimg=resume_kimg, resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True, num_repeats=1,
        Gs_kwargs=dict(is_validation=True), resume_with_new_nets=False, truncations=[None]):
    """Evaluate this metric once per truncation value, optionally averaging repeats.

    For every entry of ``truncations`` the metric is evaluated with
    ``truncation_psi_val`` set to that entry. With ``num_repeats > 1`` the
    evaluation is repeated and the reported results are replaced by the mean
    and standard deviation (suffix ``'-std'``) of each result.

    Args:
        network_pkl, run_dir, data_dir, dataset_args, mirror_augment:
            Forwarded to ``self._reset``.
        num_gpus: Forwarded to ``_evaluate``.
        tf_config: Forwarded to ``tflib.create_session``.
        log_results: When true, print the stripped result string for each
            truncation (through a ``dnnlib.util.Logger`` appending to
            ``metric-<name>.txt`` if ``run_dir`` is given).
        num_repeats: Number of evaluations per truncation value.
        Gs_kwargs: Base generator keyword arguments. Not modified: each
            truncation works on its own copy.
        resume_with_new_nets: Rebuild the generator from the current
            ``training.co_mod_gan.G_main`` code and copy the loaded weights
            into it before evaluating.
        truncations: Truncation values to evaluate.
    """
    self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir,
                dataset_args=dataset_args, mirror_augment=mirror_augment)

    with tf.Graph().as_default(), tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
        self._report_progress(0, 1)
        _G, _D, Gs = misc.load_pkl(self._network_pkl)
        if resume_with_new_nets:
            dataset = self._get_dataset_obj()
            G = dnnlib.tflib.Network(
                'G',
                num_channels=dataset.shape[0],
                resolution=dataset.shape[1],
                label_size=dataset.label_size,
                func_name='training.co_mod_gan.G_main',
                pix2pix=dataset.pix2pix)
            Gs_new = G.clone('Gs')
            Gs_new.copy_vars_from(Gs)
            Gs = Gs_new

        for t in truncations:
            print('truncation={}'.format(t))
            self._results = []
            time_begin = time.time()

            # Work on a copy: updating Gs_kwargs in place would leak
            # truncation_psi_val into the shared default dict (and into the
            # caller's dict) across calls.
            eval_kwargs = dict(Gs_kwargs)
            eval_kwargs.update(truncation_psi_val=t)

            self._evaluate(Gs, Gs_kwargs=eval_kwargs, num_gpus=num_gpus)
            self._report_progress(1, 1)

            if num_repeats > 1:
                # One record per reported result, collecting its value across repeats.
                records = [
                    dnnlib.EasyDict(value=[res.value], suffix=res.suffix, fmt=res.fmt)
                    for res in self._results
                ]
                for _ in range(1, num_repeats):
                    print(self.get_result_str().strip())
                    self._results = []
                    self._report_progress(0, 1)
                    self._evaluate(Gs, Gs_kwargs=eval_kwargs, num_gpus=num_gpus)
                    self._report_progress(1, 1)
                    for rec, res in zip(records, self._results):
                        rec.value.append(res.value)
                # Replace the per-run results with mean and std over all repeats.
                self._results = []
                for rec in records:
                    self._report_result(np.mean(rec.value), rec.suffix, rec.fmt)
                    self._report_result(np.std(rec.value), rec.suffix + '-std', rec.fmt)

            self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init

            if log_results:
                if run_dir is not None:
                    log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                    with dnnlib.util.Logger(log_file, 'a'):
                        print(self.get_result_str().strip())
                else:
                    print(self.get_result_str().strip())
from training import dataset
import numpy as np
import tensorflow as tf
import training.misc as misc
import matplotlib.pyplot as plt

# Script fragment: opens a TFRecord dataset and a trained snapshot, then starts
# building a label whose rotation entry is shifted by one step.
# NOTE(review): `tflib` is used below but not imported in this part of the file;
# confirm it is imported above.

tfrecord_dir = '../../datasets/cars_v5_512'
tflib.init_tf({'gpu_options.allow_growth': True})
# Single pass over the records with no minibatch shuffling.
training_set = dataset.TFRecordDataset(tfrecord_dir,
                                       max_label_size='full',
                                       repeat=False,
                                       shuffle_mb=0)
tflib.init_uninitialized_vars()
session = tflib.create_session(None, force_as_default=True)

# Placeholders with hard-coded widths (512-d latent, 16x512 dlatent, 127-d label);
# presumably these match the snapshot loaded below. TODO confirm.
latent_placeholder = tf.placeholder(tf.float32, shape=(None, 512))
dlatent_placeholder = tf.placeholder(tf.float32, shape=(None, 16, 512))
label_placeholder = tf.placeholder(tf.float32, shape=(None, 127))
G, D, Gs = misc.load_pkl(
    '../../results/00155-stylegan2-cars_v5_512-2gpu-config-f/network-snapshot-010467.pkl'
)

num_steps = 10
label_left = training_set.get_random_labels_np(1)
# Columns [108, 116) of the label are read as 8 rotation scores
# (meaning inferred from the variable names; TODO confirm the label layout).
rotation_offset = 108
rotations = label_left[:, rotation_offset:rotation_offset + 8]
rotation_index = np.argmax(rotations, axis=1)
# Step to a neighbouring index in a random direction, wrapping around modulo 8.
new_rotation_index = ((rotation_index + np.random.choice([-1, 1])) % 8)
new_rotation = np.zeros([8], dtype=np.uint32)