Example #1
    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0
            self._mirror_augment = run_config['train'].get('mirror_augment', False)

        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)
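
Usage sketch (illustrative, not part of the original): how a metric object exposing this run() method might be driven. The FID class and its constructor arguments follow the StyleGAN/StyleGAN2 metrics layout and are assumptions here; paths are placeholders.

# Hedged usage sketch -- FID class name/signature assumed from the StyleGAN2 metrics package.
from metrics import frechet_inception_distance

metric = frechet_inception_distance.FID(name='fid50k', num_images=50000, minibatch_per_gpu=8)
metric.run(network_pkl='results/00000-example/network-snapshot-010000.pkl',  # placeholder
           run_dir='results/00000-example',  # dataset args are recovered from this run dir
           num_gpus=1,
           log_results=True)
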
Example #2
    def run(self,
            network_pkl,
            run_dir=None,
            data_dir=None,
            dataset_args=None,
            mirror_augment=None,
            num_gpus=1,
            tf_config=None,
            log_results=True,
            Gs_kwargs=dict(is_validation=True)):
        self._reset(network_pkl=network_pkl,
                    run_dir=run_dir,
                    data_dir=data_dir,
                    dataset_args=dataset_args,
                    mirror_augment=mirror_augment)
        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            # _G, _D, Gs = misc.load_pkl(self._network_pkl)
            _G, _D, _I, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
            self._report_progress(1, 1)
        self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init

        if log_results:
            if run_dir is not None:
                log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log_file, 'a'):
                    print(self.get_result_str().strip())
            else:
                print(self.get_result_str().strip())
Example #3
    def run(self, network_pkl, num_gpus=1, G_kwargs=dict(is_validation=True)):
        self._results = []
        self._network_name = os.path.splitext(os.path.basename(network_pkl))[0]
        self._eval_time = 0
        self._dataset = None

        with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            time_begin = time.time()
            with dnnlib.util.open_url(network_pkl) as f:
                G, D, Gs = pickle.load(f)

            G_kwargs = dnnlib.EasyDict(G_kwargs)
            G_kwargs.update(self.force_G_kwargs)
            self._evaluate(G=G,
                           D=D,
                           Gs=Gs,
                           G_kwargs=G_kwargs,
                           num_gpus=num_gpus)

            self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init
            self._report_progress(1, 1)
            if self._dataset is not None:
                self._dataset.close()
                self._dataset = None

        result_str = self.get_result_str()
        tqdm.write(result_str)
        if self._run_dir is not None and os.path.isdir(self._run_dir):
            with open(os.path.join(self._run_dir, f'metric-{self.name}.txt'),
                      'at') as f:
                f.write(result_str + '\n')
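
A minimal generation sketch (not part of the original) using the same open_url/pickle loading pattern; the pickle path is a placeholder and the output transform uses the standard tflib helper.

# Hedged sketch, assuming the dnnlib/tflib packages from the surrounding codebase.
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib

tflib.init_tf()
with dnnlib.util.open_url('network-snapshot.pkl') as f:  # placeholder; older dnnlib versions accept URLs only
    _G, _D, Gs = pickle.load(f)

z = np.random.RandomState(0).randn(1, *Gs.input_shape[1:])  # one latent vector
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(z, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)  # labels=None for unconditional nets
PIL.Image.fromarray(images[0], 'RGB').save('example.png')
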
Example #4
    def run(self,
            network_pkl,
            run_dir=None,
            dataset_args=None,
            mirror_augment=None,
            num_gpus=1,
            tf_config=None,
            log_results=True,
            model_type="rignet"):

        create_dir(config.EVALUATION_DIR, exist_ok=True)

        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []
        self.model_type = model_type

        if (dataset_args is None
                or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0
            self._mirror_augment = run_config['train'].get(
                'mirror_augment', False)

        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():  # pylint: disable=not-context-manager
            E, _G, _D, Gs = misc.load_pkl(self._network_pkl)
            print("Loaded Encoder")
            Inv, _, _, _ = misc.load_pkl(config.INVERSION_PICKLE_DIR)
            print("Loaded Inv")
            self._evaluate(Gs, E, Inv, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)

            result_path = os.path.join(
                config.EVALUATION_DIR, "result_" +
                convert_pickle_path_to_name(self._network_pkl) + ".txt")
            write_to_file(result_str + "\n\n\n", result_path)
Example #5
    def run(self,
            network_pkl,
            num_imgs,
            run_dir=None,
            data_dir=None,
            dataset_args=None,
            mirror_augment=None,
            ratio=1.0,
            num_gpus=1,
            tf_config=None,
            log_results=True,
            Gs_kwargs=dict(is_validation=True),
            eval_mod=False,
            **kwargs):

        self._reset(network_pkl=network_pkl,
                    run_dir=run_dir,
                    data_dir=data_dir,
                    dataset_args=dataset_args,
                    mirror_augment=mirror_augment)
        self.eval_mod = eval_mod

        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():
            self._report_progress(0, 1)
            _G = _D = Gs = None
            if self._network_pkl is not None:
                _G, _D, Gs = misc.load_pkl(self._network_pkl)[:3]
            self._evaluate(Gs,
                           Gs_kwargs=Gs_kwargs,
                           num_gpus=num_gpus,
                           num_imgs=num_imgs,
                           ratio=ratio,
                           **kwargs)
            self._report_progress(1, 1)
        self._eval_time = time.time() - time_begin

        if log_results:
            if run_dir is not None:
                log_file = os.path.join(run_dir, "metric-%s.txt" % self.name)
                with dnnlib.util.Logger(log_file, "a", screen=False):
                    print(self.get_result_str().strip())
            print(self.get_result_str(screen=True).strip())

        return self._results[0].value
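
Since this variant returns the first recorded value, a caller can consume it directly; the subclass name below is hypothetical.

# Hypothetical metric subclass implementing _evaluate(); arguments are illustrative.
metric = MyFIDMetric(name='fid50k')
fid_value = metric.run('network-snapshot.pkl',  # placeholder pickle
                       num_imgs=50000, num_gpus=1, log_results=False)
print('FID: %.2f' % fid_value)
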
Example #6
    def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True,
            include_I=False, avg_mv_for_I=False, Gs_kwargs=dict(is_validation=True, return_atts=False), train_infernet=False, is_vae=False, use_D=False,
            **kwargs):
        self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment)
        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            if include_I:
                if avg_mv_for_I:
                    _G, _D, _I, Gs, I = misc.load_pkl(self._network_pkl)
                else:
                    _G, _D, I, Gs = misc.load_pkl(self._network_pkl)
                outs = self._evaluate(Gs=Gs, Gs_kwargs=Gs_kwargs, I_net=I, num_gpus=num_gpus, **kwargs)
            elif train_infernet:
                I, Gs = misc.load_pkl(self._network_pkl)
                outs = self._evaluate(Gs=Gs, Gs_kwargs=Gs_kwargs, I_net=I, num_gpus=num_gpus, **kwargs)
            elif is_vae:
                if use_D:
                    I, Gs, D = misc.load_pkl(self._network_pkl)
                else:
                    I, Gs = misc.load_pkl(self._network_pkl)
                outs = self._evaluate(Gs=Gs, Gs_kwargs=Gs_kwargs, I_net=I, num_gpus=num_gpus, **kwargs)
            else:
                _G, _D, Gs = misc.load_pkl(self._network_pkl)
                outs = self._evaluate(Gs=Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus, **kwargs)
            self._report_progress(1, 1)
        self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init

        if log_results:
            if run_dir is not None:
                log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log_file, 'a'):
                    print(self.get_result_str().strip())
            else:
                print(self.get_result_str().strip())
        return outs
Example #7
def setup_training_options(
        # General options (not included in desc).
        gpus=None,  # Number of GPUs: <int>, default = 1 gpu
        snap=None,  # Snapshot interval: <int>, default = 50 ticks

        # Training dataset.
        data=None,  # Training dataset (required): <path>
        res=None,  # Override dataset resolution: <int>, default = highest available
        mirror=None,  # Augment dataset with x-flips: <bool>, default = False
        mirrory=None,  # Augment dataset with y-flips: <bool>, default = False
        use_raw=None,  # Forwarded to the dataset loader as-is: <bool>

        # Metrics (not included in desc).
        metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
        metricdata=None,  # Metric dataset (optional): <path>

        # Base config.
        cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
        gamma=None,  # Override R1 gamma: <float>, default = depends on cfg
        kimg=None,  # Override training duration: <int>, default = depends on cfg

        # Discriminator augmentation.
        aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
        p=None,  # Specify p for 'fixed' (required): <float>
        target=None,  # Override ADA target for 'ada' and 'adarv': <float>, default = depends on aug
        augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

        # Comparison methods.
        cmethod=None,  # Comparison method: 'nocmethod' (default), 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'
        dcap=None,  # Multiplier for discriminator capacity: <float>, default = 1

        # Transfer learning.
        resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
):
    # Initialize dicts.
    args = dnnlib.EasyDict()
    args.G_args = dnnlib.EasyDict(func_name='training.networks.G_main')
    args.D_args = dnnlib.EasyDict(func_name='training.networks.D_main')
    args.G_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.D_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.loss_args = dnnlib.EasyDict(func_name='training.loss.stylegan2')
    args.augment_args = dnnlib.EasyDict(
        class_name='training.augment.AdaptiveAugment')

    # ---------------------------
    # General options: gpus, snap
    # ---------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    # ---------------------------------------------
    # Training dataset: data, res, mirror, mirrory
    # ---------------------------------------------

    assert data is not None
    assert isinstance(data, str)
    data_name = os.path.basename(os.path.abspath(data))
    if not os.path.isdir(data) or len(data_name) == 0:
        raise UserError(
            '--data must point to a directory containing *.tfrecords')
    desc = data_name

    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        args.train_dataset_args = dnnlib.EasyDict(path=data,
                                                  max_label_size='full')
        args.train_dataset_args.use_raw = use_raw
        dataset_obj = dataset.load_dataset(
            **args.train_dataset_args
        )  # try to load the data and see what comes out
        args.train_dataset_args.resolution = dataset_obj.shape[
            -1]  # be explicit about resolution
        args.train_dataset_args.max_label_size = dataset_obj.label_size  # be explicit about label size
        validation_set_available = dataset_obj.has_validation_set
        dataset_obj.close()
        dataset_obj = None

    if res is None:
        res = args.train_dataset_args.resolution
    else:
        assert isinstance(res, int)
        if not (res >= 4 and res & (res - 1) == 0):
            raise UserError('--res must be a power of two and at least 4')
        if res > args.train_dataset_args.resolution:
            raise UserError(
                f'--res cannot exceed maximum available resolution in the dataset ({args.train_dataset_args.resolution})'
            )
        desc += f'-res{res:d}'
    args.train_dataset_args.resolution = res

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
    args.train_dataset_args.mirror_augment = mirror

    if mirrory is None:
        mirrory = False
    assert isinstance(mirrory, bool)
    if mirrory:
        desc += '-mirrory'
    args.train_dataset_args.mirror_augment_v = mirrory

    args.train_dataset_args.use_raw = use_raw

    # ----------------------------
    # Metrics: metrics, metricdata
    # ----------------------------

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    assert all(isinstance(metric, str) for metric in metrics)

    args.metric_arg_list = []
    for metric in metrics:
        if metric not in metric_defaults.metric_defaults:
            raise UserError('\n'.join(
                ['--metrics can only contain the following values:', 'none'] +
                list(metric_defaults.metric_defaults.keys())))
        args.metric_arg_list.append(metric_defaults.metric_defaults[metric])

    args.metric_dataset_args = dnnlib.EasyDict(args.train_dataset_args)
    if metricdata is not None:
        assert isinstance(metricdata, str)
        if not os.path.isdir(metricdata):
            raise UserError(
                '--metricdata must point to a directory containing *.tfrecords'
            )
        args.metric_dataset_args.path = metricdata

    # -----------------------------
    # Base config: cfg, gamma, kimg
    # -----------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':
        dict(ref_gpus=-1,
             kimg=25000,
             mb=-1,
             mbstd=-1,
             fmaps=-1,
             lrate=-1,
             gamma=-1,
             ema=-1,
             ramp=0.05,
             map=8),  # populated dynamically based on 'gpus' and 'res'
        '11gb-gpu':
        dict(ref_gpus=1,
             kimg=25000,
             mb=4,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, 11GB GPU
        '11gb-gpu-complex':
        dict(ref_gpus=1,
             kimg=25000,
             mb=4,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, 11GB GPU
        '24gb-gpu':
        dict(ref_gpus=1,
             kimg=25000,
             mb=8,
             mbstd=8,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, 24GB GPU
        '24gb-gpu-complex':
        dict(ref_gpus=1,
             kimg=25000,
             mb=8,
             mbstd=8,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, 24GB GPU
        '48gb-gpu':
        dict(ref_gpus=1,
             kimg=25000,
             mb=16,
             mbstd=16,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, 48GB GPU
        'stylegan2':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # uses mixed-precision, unlike original StyleGAN2
        'paper256':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=0.5,
             lrate=0.0025,
             gamma=1,
             ema=20,
             ramp=None,
             map=8),
        'paper512':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=1,
             lrate=0.0025,
             gamma=0.5,
             ema=20,
             ramp=None,
             map=8),
        'paper1024':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=2,
             ema=10,
             ramp=None,
             map=8),
        'cifar':
        dict(ref_gpus=2,
             kimg=100000,
             mb=64,
             mbstd=32,
             fmaps=0.5,
             lrate=0.0025,
             gamma=0.01,
             ema=500,
             ramp=0.05,
             map=2),
        'cifarbaseline':
        dict(ref_gpus=2,
             kimg=100000,
             mb=64,
             mbstd=32,
             fmaps=0.5,
             lrate=0.0025,
             gamma=0.01,
             ema=500,
             ramp=0.05,
             map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        spec.mb = max(min(gpus * min(4096 // res, 32), 64),
                      gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.total_kimg = spec.kimg
    args.minibatch_size = spec.mb
    args.minibatch_gpu = spec.mb // spec.ref_gpus
    args.D_args.mbstd_group_size = spec.mbstd
    args.G_args.fmap_base = args.D_args.fmap_base = int(spec.fmaps * 16384)
    args.G_args.fmap_max = args.D_args.fmap_max = 512
    args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = spec.lrate
    args.loss_args.r1_gamma = spec.gamma
    args.G_smoothing_kimg = spec.ema
    args.G_smoothing_rampup = spec.ramp
    args.G_args.mapping_layers = spec.map
    args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4  # enable mixed-precision training
    args.G_args.conv_clamp = args.D_args.conv_clamp = 256  # clamp activations to avoid float16 overflow

    if cfg == 'cifar' or cfg.split('-')[-1] == 'complex':
        args.loss_args.pl_weight = 0  # disable path length regularization
        args.G_args.style_mixing_prob = None  # disable style mixing
        args.D_args.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_args.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.augment_args.tune_heuristic = 'rt'
        args.augment_args.tune_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    elif aug == 'adarv':
        if not validation_set_available:
            raise UserError(
                f'--aug={aug} requires separate validation set; please see "python dataset_tool.py pack -h"'
            )
        args.augment_args.tune_heuristic = 'rv'
        args.augment_args.tune_target = 0.5

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_args.initial_strength = p

    if target is not None:
        assert isinstance(target, float)
        if aug not in ['ada', 'adarv']:
            raise UserError(
                '--target can only be specified with --aug=ada or --aug=adarv')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.augment_args.tune_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':
        dict(xflip=1, rotate90=1, xint=1),
        'geom':
        dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':
        dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter':
        dict(imgfilter=1),
        'noise':
        dict(noise=1),
        'cutout':
        dict(cutout=1),
        'bg':
        dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1),
        'bgcf':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1),
        'bgcfn':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1),
        'bgcfnc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1,
             cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_args.apply_func = 'training.augment.augment_pipeline'
        args.augment_args.apply_args = augpipe_specs[augpipe]

    # ---------------------------------
    # Comparison methods: cmethod, dcap
    # ---------------------------------

    assert cmethod is None or isinstance(cmethod, str)
    if cmethod is None:
        cmethod = 'nocmethod'
    else:
        desc += f'-{cmethod}'

    if cmethod == 'nocmethod':
        pass

    elif cmethod == 'bcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.bcr_real_weight = 10
        args.loss_args.bcr_fake_weight = 10
        args.loss_args.bcr_augment = dnnlib.EasyDict(
            func_name='training.augment.augment_pipeline',
            xint=1,
            xint_max=1 / 32)

    elif cmethod == 'zcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.zcr_gen_weight = 0.02
        args.loss_args.zcr_dis_weight = 0.2
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None

    elif cmethod == 'pagan':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.use_pagan = True
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func',
                              None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.95

    elif cmethod == 'wgangp':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        if gamma is not None:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with --gamma')
        args.loss_args = dnnlib.EasyDict(func_name='training.loss.wgangp')
        args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = 0.001
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
        args.lazy_regularization = False

    elif cmethod == 'auxrot':
        if args.train_dataset_args.max_label_size > 0:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with label conditioning; please specify a dataset without labels'
            )
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.auxrot_alpha = 10
        args.loss_args.auxrot_beta = 5
        args.D_args.score_max = 5  # prepare D to output 5 scalars per image instead of just 1

    elif cmethod == 'spectralnorm':
        args.D_args.use_spectral_norm = True

    elif cmethod == 'shallowmap':
        if args.G_args.mapping_layers == 2:
            raise UserError(f'--cmethod={cmethod} is a no-op for --cfg={cfg}')
        args.G_args.mapping_layers = 2

    elif cmethod == 'adropout':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.adaptive_dropout = 1
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func',
                              None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.6

    else:
        raise UserError(f'--cmethod={cmethod} not supported')

    if dcap is not None:
        assert isinstance(dcap, float)
        if not dcap > 0:
            raise UserError('--dcap must be positive')
        desc += f'-dcap{dcap:g}'
        args.D_args.fmap_base = max(int(args.D_args.fmap_base * dcap), 1)
        args.D_args.fmap_max = max(int(args.D_args.fmap_max * dcap), 1)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
        'afhqcat512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl',
        'afhqdog512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl',
        'afhqwild512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqwild.pkl',
        'brecahad512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/brecahad.pkl',
        'cifar10':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl',
        'metfaces':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.augment_args.tune_kimg = 100  # make ADA react faster at the beginning
        args.G_smoothing_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_args.freeze_layers = freezed

    return desc, args
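
Illustrative call site (not from the original): the returned desc names the run and args feeds the training loop; the dataset path below is a placeholder.

# Hedged usage sketch; 'datasets/my_dataset' must contain *.tfrecords and is a placeholder.
try:
    desc, args = setup_training_options(
        gpus=1,
        data='datasets/my_dataset',
        mirror=True,
        cfg='auto',
        aug='ada',
        metrics=['fid50k_full'],
    )
except UserError as err:
    raise SystemExit(f'Error: {err}')

print('Run description:', desc)               # e.g. used to name the results directory
print('Training length (kimg):', args.total_kimg)
print('R1 gamma:', args.loss_args.r1_gamma)
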
Example #8
    def load_network(self, network_path: str) -> None:
        self.session = tflib.create_session(None, force_as_default=True)

        # Load Networks
        self.image_resolution = 256  # default resolution; overridden from the checkpoint name below
        replace_res = re.findall('256|512', network_path)
        if replace_res:
            self.image_resolution = int(replace_res[0])

        self.current_network_checkpoint = network_path
        _, _, self.Gs = misc.load_pkl('../../results/' + network_path)
        self._vgg16_model = misc.load_pkl(
            'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/vgg16_zhang_perceptual.pkl'
        )

        # Define Placeholders
        self.latent_placeholder = tf.placeholder(tf.float32,
                                                 shape=self.Gs.input_shapes[0])
        self.label_placeholder = tf.placeholder(tf.float32,
                                                shape=self.Gs.input_shapes[1])
        self.dlatent_placeholder = tf.placeholder(
            tf.float32, shape=self.Gs.components.synthesis.input_shape)
        self.dlatent_left_placeholder = tf.placeholder(
            tf.float32, self.Gs.components.synthesis.input_shape)
        self.dlatent_right_placeholder = tf.placeholder(
            tf.float32, self.Gs.components.synthesis.input_shape)
        self.magnitude_placeholder = tf.placeholder(tf.float32, shape=())
        self.delta_magnitude_placeholder = tf.placeholder(tf.float32, shape=())

        if self.Gs.input_shapes[1][1] == 121:
            self.label_version = 'v7'
        if self.Gs.input_shapes[1][1] == 127:
            self.label_version = 'v5'
        read_label_version = re.findall(r'v\d+', network_path)
        if read_label_version:
            self.label_version = read_label_version[0]

        # Generate Image
        self._output_image = self.Gs.get_output_for(self.latent_placeholder,
                                                    self.label_placeholder,
                                                    randomize_noise=False,
                                                    truncation_psi=None,
                                                    style_mixing_prob=None)
        self._synthesis = self.Gs.components.synthesis.get_output_for(
            self.dlatent_placeholder, randomize_noise=False)

        if 'baseline_without_labels' in self.current_network_checkpoint:
            zero_label = tf.zeros(shape=[1, 0], dtype=tf.float32)
            self._mapping_without_labels = self.Gs.components.mapping.get_output_for(
                self.latent_placeholder, zero_label)

        # Separate Mapping
        if 'mapping' in self.Gs.components:
            self.separate_mapping = False
            self._mapping = self.Gs.components.mapping.get_output_for(
                self.latent_placeholder, self.label_placeholder)
        else:
            self.separate_mapping = True
            self.concat = self.Gs.components.synthesis.input_shape[2] > 512
            self._latent_separate_mapping = self.Gs.components.mapping_latent.get_output_for(
                self.latent_placeholder)
            self._label_separate_mapping = self.Gs.components.mapping_label.get_output_for(
                self.label_placeholder)

        # Interpolation Gradient Graph
        dlatent_int_delta, dlatent_int = utils.delta_lerp(
            a=self.dlatent_left_placeholder,
            b=self.dlatent_right_placeholder,
            mag=self.magnitude_placeholder,
            delta_mag=self.delta_magnitude_placeholder,
        )
        image_int_delta = self.Gs.components.synthesis.get_output_for(
            dlatent_int_delta, randomize_noise=False)
        self._gradients = tf.gradients(image_int_delta, [dlatent_int])[0]

        # VGG16 Graph
        image_int = self.Gs.components.synthesis.get_output_for(
            dlatent_int, randomize_noise=False)
        self._vgg16_distance = self._vgg16_model.get_output_for(
            image_int, image_int_delta)

        print('Successfully loaded:', self.current_network_checkpoint)
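
Hedged sketch of evaluating the graph built above; viewer stands for an instance of the surrounding class (its name is not shown in the snippet), and the checkpoint path and random inputs are assumptions.

import numpy as np

viewer = Visualizer()  # hypothetical class containing the load_network() method above
viewer.load_network('00000-example/network-snapshot-010000.pkl')  # placeholder, resolved under ../../results/
z = np.random.randn(1, *viewer.Gs.input_shapes[0][1:]).astype(np.float32)   # one latent
label = np.zeros([1, viewer.Gs.input_shapes[1][1]], dtype=np.float32)       # empty conditioning label
image = viewer.session.run(viewer._output_image,
                           feed_dict={viewer.latent_placeholder: z,
                                      viewer.label_placeholder: label})
print('Output tensor shape (NCHW):', image.shape)
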
Example #9
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v, \
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    tfr_files = file_list(os.path.dirname(data), 'tfr')
    tfr_files = [
        f for f in tfr_files if basename(data) == basename(f).split('-')[0]
    ]
    if len(tfr_files) == 0 or os.stat(tfr_files[0]).st_size == 0:
        tfr_file, total_samples = create_from_image_folders(
            data) if cond is True else create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(
            **dataset_args)  # loading the data to see what comes out
        resolution = dataset_obj.resolution
        init_res = dataset_obj.init_res
        res_log2 = dataset_obj.res_log2
        dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (basename(data), resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    train.savenames = [desc.replace(basename(data), 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug: loss_args.gamma = 100 if gamma is None else gamma
        else: D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        print(' Only configs E and F are implemented')
        exit()

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
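
Hypothetical invocation of the launcher above; every argument value is illustrative rather than taken from the original command-line wrapper.

run(data='data/mydataset',           # placeholder image folder / tfrecords base name
    train_dir='train', config='f',
    d_aug=False, diffaug_policy='', cond=False, ops='cuda',
    mirror=True, mirror_v=False,
    kimg=5000, batch_size=None, lrate=0.001,
    resume=None, resume_kimg=0, num_gpus=1,
    ema_kimg=10, gamma=None, freezeD=False)
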
Example #10
    def run(self,
            network_pkl,
            run_dir=None,
            data_dir=None,
            dataset_args=None,
            mirror_augment=None,
            num_gpus=1,
            tf_config=None,
            log_results=True,
            num_repeats=1,
            Gs_kwargs=dict(is_validation=True),
            resume_with_new_nets=False,
            truncations=[None]):
        self._reset(network_pkl=network_pkl,
                    run_dir=run_dir,
                    data_dir=data_dir,
                    dataset_args=dataset_args,
                    mirror_augment=mirror_augment)
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            _G, _D, Gs = misc.load_pkl(self._network_pkl)

            if resume_with_new_nets:
                dataset = self._get_dataset_obj()
                G = dnnlib.tflib.Network(
                    'G',
                    num_channels=dataset.shape[0],
                    resolution=dataset.shape[1],
                    label_size=dataset.label_size,
                    func_name='training.co_mod_gan.G_main',
                    pix2pix=dataset.pix2pix)
                Gs_new = G.clone('Gs')
                Gs_new.copy_vars_from(Gs)
                Gs = Gs_new

            for t in truncations:
                print('truncation={}'.format(t))
                self._results = []
                time_begin = time.time()

                Gs_kwargs.update(truncation_psi_val=t)
                self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
                self._report_progress(1, 1)

                if num_repeats > 1:
                    records = [
                        dnnlib.EasyDict(value=[res.value],
                                        suffix=res.suffix,
                                        fmt=res.fmt) for res in self._results
                    ]
                    for i in range(1, num_repeats):
                        print(self.get_result_str().strip())
                        self._results = []
                        self._report_progress(0, 1)
                        self._evaluate(Gs,
                                       Gs_kwargs=Gs_kwargs,
                                       num_gpus=num_gpus)
                        self._report_progress(1, 1)
                        for rec, res in zip(records, self._results):
                            rec.value.append(res.value)

                    self._results = []
                    for rec in records:
                        self._report_result(np.mean(rec.value), rec.suffix,
                                            rec.fmt)
                        self._report_result(np.std(rec.value),
                                            rec.suffix + '-std', rec.fmt)

                self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init

                if log_results:
                    if run_dir is not None:
                        log_file = os.path.join(run_dir,
                                                'metric-%s.txt' % self.name)
                        with dnnlib.util.Logger(log_file, 'a'):
                            print(self.get_result_str().strip())
                    else:
                        print(self.get_result_str().strip())
Example #11
from training import dataset
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
import training.misc as misc
import matplotlib.pyplot as plt

tfrecord_dir = '../../datasets/cars_v5_512'

tflib.init_tf({'gpu_options.allow_growth': True})
training_set = dataset.TFRecordDataset(tfrecord_dir,
                                       max_label_size='full',
                                       repeat=False,
                                       shuffle_mb=0)
tflib.init_uninitialized_vars()

session = tflib.create_session(None, force_as_default=True)
latent_placeholder = tf.placeholder(tf.float32, shape=(None, 512))
dlatent_placeholder = tf.placeholder(tf.float32, shape=(None, 16, 512))
label_placeholder = tf.placeholder(tf.float32, shape=(None, 127))
G, D, Gs = misc.load_pkl(
    '../../results/00155-stylegan2-cars_v5_512-2gpu-config-f/network-snapshot-010467.pkl'
)

num_steps = 10

label_left = training_set.get_random_labels_np(1)
rotation_offset = 108
rotations = label_left[:, rotation_offset:rotation_offset + 8]
rotation_index = np.argmax(rotations, axis=1)
new_rotation_index = ((rotation_index + np.random.choice([-1, 1])) % 8)
new_rotation = np.zeros([8], dtype=np.uint32)
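
The script stops mid-construction; a purely illustrative continuation (not from the original) would finish the shifted one-hot rotation, write it into a copy of the label, and render both labels for comparison.

# Illustrative continuation only -- the original snippet ends above this point.
new_rotation[new_rotation_index[0]] = 1                     # shifted one-hot rotation bin
label_right = label_left.copy()
label_right[:, rotation_offset:rotation_offset + 8] = new_rotation

latent = np.random.randn(1, 512).astype(np.float32)
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
img_a = Gs.run(latent, label_left, randomize_noise=False, output_transform=fmt)
img_b = Gs.run(latent, label_right, randomize_noise=False, output_transform=fmt)
plt.imshow(np.concatenate([img_a[0], img_b[0]], axis=1))
plt.show()
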