def locate_latest_pkl(train_dir):
    """Find the last snapshot pickle in ``train_dir`` and the kimg it encodes.

    Candidates are the 'pkl' entries reported by ``file_list`` whose path
    contains 'snapshot'. "Last" means last in plain string ordering.

    Returns:
        ``(path, kimg)`` where ``kimg`` is the float parsed from the final
        '-'-separated piece of ``basename(path)``, or ``(None, 0.0)`` when
        there is no snapshot pickle.
    """
    snapshots = [path for path in file_list(train_dir, 'pkl') if 'snapshot' in path]
    if not snapshots:
        return None, 0.0
    newest = max(snapshots)
    # presumably basename() drops the extension, otherwise float() would fail
    *_, kimg_tag = basename(newest).split('-')
    return newest, float(kimg_tag)
def main():
    """Render an image sequence by interpolating between saved dlatents.

    All settings come from the module-level ``a`` namespace. The function
    loads the ``G_ema`` generator for ``a.model``, collects key dlatents from
    the file or directory ``a.dlatents``, optionally overrides layers 5 and
    above with ``a.style_dlat``, expands the keys into a per-frame timeline
    with ``latent_anima`` and writes one image per frame into ``a.out_dir``.

    Exits with status 1 when no input dlatents are available (missing path or
    a directory without any 'npy' file).
    """
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network: a model path without '.pkl' in it selects
    # "custom" loading; in both cases the file opened is '<stem>.pkl'
    pkl_name = osp.splitext(a.model)[0]
    custom = '.pkl' not in a.model.lower()
    if custom:
        print(' .. Gs custom ..', basename(a.model))
    else:
        print(' .. Gs from pkl ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(
            device)  # type: ignore

    dlat_shape = (1, Gs.num_ws, Gs.w_dim)  # [1,18,512]

    # read saved latents
    key_dlatents = None
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        loaded = []
        for npy in file_list(a.dlatents, 'npy'):
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            loaded.append(key_dlatent)
        # np.concatenate raises on an empty list, so a directory without any
        # latent file is reported like a missing input instead of crashing
        if loaded:
            key_dlatents = np.concatenate(loaded)  # [frm,18,512]
    if key_dlatents is None:
        print(' No input dlatents found')
        exit(1)
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than Gs.num_ws
        style_layers = range(5, Gs.num_ws)
        key_dlatents[:, :, style_layers, :] = style_dlatent[:, :, style_layers, :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]
    dlatents = torch.from_numpy(dlatents).to(device)

    # distort image by tweaking initial const layer
    if a.digress > 0:
        # default initial layer size when the generator does not expose one
        init_res = getattr(Gs, 'init_res', (4, 4))
        dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res],
                                          frame_count,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        for i in range(frame_count):

            # generate multi-latent result
            if custom:
                output = Gs.synthesis(dlatents[i],
                                      None,
                                      dconst[i],
                                      noise_mode='const')
            else:
                output = Gs.synthesis(dlatents[i], noise_mode='const')
            # NCHW -> NHWC, then scale to the 0..255 uint8 range
            output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(
                0, 255).to(torch.uint8).cpu().numpy()

            # 4-channel output is saved as png, anything else as jpg
            ext = 'png' if output.shape[3] == 4 else 'jpg'
            filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
            imsave(filename, output[0])
            pbar.upd()
Exemple #3
0
    def __init__(
        self,
        tfrecord,  # tfrecords file
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2):  # Number of concurrent threads.
        """Open one tfrecords file and build the TF input pipeline for it.

        The image shape is read from the first record of the file, the label
        file is located (or autodetected next to the tfrecords file), labels
        are loaded, and a shuffled / repeated / prefetched / batched
        ``tf.data`` pipeline with a re-initializable iterator is created.

        Raises:
            AssertionError: if ``tfrecord`` is not an existing file, if
                ``max_label_size`` is invalid, or if the labels are not 2-D.
            ValueError: if the tfrecords file contains no records.
        """

        self.tfrecord = tfrecord
        self.resolution = None
        self.res_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # components
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_dataset = None
        self._tf_iterator = None
        self._tf_init_op = None
        self._tf_minibatch_np = None
        self._cur_minibatch = -1

        # List tfrecords files and inspect their shapes.
        assert os.path.isfile(self.tfrecord)

        tfr_file = self.tfrecord
        data_dir = os.path.dirname(tfr_file)

        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        # Only the first record is inspected. An empty file is reported
        # explicitly rather than failing later on an unbound tfr_shape.
        first_record = next(
            iter(tf.python_io.tf_record_iterator(tfr_file, tfr_opt)), None)
        if first_record is None:
            raise ValueError('No records found in tfrecords file: %s' %
                             tfr_file)
        tfr_shape = self.parse_tfrecord_shape(first_record)  # [c,h,w]
        # fewer than 4 channels selects the jpg parser below
        jpg_data = tfr_shape[0] < 4

        # Autodetect label filename.
        if self.label_file is None:
            # candidate '.labels' files sharing the tfrecord's name prefix
            # (the part of the basename before the first '-')
            guess = [
                ff for ff in file_list(data_dir, 'labels') if basename(
                    ff).split('-')[0] == basename(tfrecord).split('-')[0]
            ]
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # treat the given path as relative to the tfrecords directory
            guess = os.path.join(data_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution
        self.shape = list(tfr_shape)
        max_res = calc_res(tfr_shape[1:])
        self.resolution = resolution if resolution is not None else max_res
        self.res_log2 = int(np.ceil(np.log2(self.resolution)))
        # height/width scaled down by 2**(res_log2 - 2)
        self.init_res = [
            int(s * 2**(2 - self.res_log2)) for s in self.shape[1:]
        ]

        # Load labels
        assert max_label_size == 'full' or max_label_size >= 0
        # zero-width placeholder used when no labels are loaded
        self._np_labels = np.zeros([1 << 30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[
                1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type='',
                                           buffer_size=buffer_mb << 20)
            if jpg_data is True:
                dset = dset.map(self.parse_tfrecord_tf_jpg,
                                num_parallel_calls=num_threads)
            else:
                dset = dset.map(self.parse_tfrecord_tf,
                                num_parallel_calls=num_threads)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

            # buffer sizes are given in megabytes; convert them to a number
            # of items (rounded up) using the size of one decoded image
            bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle((
                    (shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch((
                    (prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_dataset = dset

            # self._tf_iterator = tf.data.Iterator.from_structure(tf.data.get_output_types(self._tf_dataset), tf.data.get_output_shapes(self._tf_dataset),)
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_dataset.output_types, self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(
                self._tf_dataset)
def main():
    """Render latent-direction loops with a TF StyleGAN2 generator.

    All settings come from the module-level ``a`` namespace. Model parameters
    (resolution, config letter, optional 'WxH' initial resolution) are parsed
    from the '-'-separated model file name. Direction vectors are loaded from
    every 'npy' file in ``a.vector_dir`` and ``make_loop`` is called once per
    direction, starting from ``a.npy_file`` or a random latent.

    Sets the module globals ``Gs`` and (when the directions are multi-layer)
    ``use_d``. Exits with status 1 for unsupported configs or when no
    direction vectors are found.
    """
    if a.vector_dir is not None:
        # drop a single trailing path separator
        if a.vector_dir.endswith(('/', '\\')):
            a.vector_dir = a.vector_dir[:-1]
    os.makedirs(osp.join(a.out_dir, 'ttt'), exist_ok=True)

    global Gs, use_d

    # parse filename to model parameters
    mparams = basename(a.model).split('-')
    res = int(mparams[1])
    cfg = mparams[2]

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_custom.G_main'
    Gs_kwargs.verbose = False
    Gs_kwargs.resolution = res
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.latent_size = a.latent_size

    if cfg.lower() == 'f':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
    elif cfg.lower() == 'e':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
        Gs_kwargs.fmap_base = 8 << 10
    else:
        print(' old modes [A-D] not implemented')
        exit(1)

    # check initial model resolution
    if len(mparams) > 3:
        if 'x' in mparams[3].lower():
            init_res = [int(x) for x in mparams[3].lower().split('x')]
            Gs_kwargs.init_res = list(reversed(init_res))  # [H,W] !!! custom res

    # load model, check channels
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE: pickle.load can execute arbitrary code; only load trusted models
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    # the pickle holds either a 3-item bundle (last item is used) or a single
    # network; anything that does not unpack as three items is used as is
    try:
        _, _, src_net = network
    except Exception:
        src_net = network
    Gs = src_net
    Gs_kwargs.num_channels = Gs.output_shape[1]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..')
    else:
        print(' .. Gs custom ..')
        Gs = tflib.Network('Gs', **Gs_kwargs)
        # copy from the unpacked network, never from the raw pickle content
        # (which may be the whole 3-item bundle)
        Gs.copy_vars_from(src_net)

    # load directions
    vector_list = file_list(a.vector_dir, 'npy') if a.vector_dir is not None else []
    if not vector_list:
        # also covers an empty directory, where np.concatenate would raise
        print(' No vectors found')
        exit(1)
    directions = []
    for v in vector_list:
        direction = load_latents(v)
        if len(direction.shape) == 2:
            direction = np.expand_dims(direction, 0)
        directions.append(direction)
    directions = np.concatenate(directions)[:, np.newaxis]  # [frm,1,18,512]

    # the shape of the last loaded vector decides the mode: more than one row
    # per entry switches to 'd' mode; otherwise the global use_d is untouched
    if len(direction[0].shape) > 1 and direction[0].shape[0] > 1:
        use_d = True
    print(' directions', directions.shape, 'using d' if use_d else 'using w')

    # latent direction range
    lrange = [-0.5, 0.5]

    # load saved latents
    if a.npy_file is not None:
        base_latent = load_latents(a.npy_file)
    else:
        print(' No NPY input given, making random')
        z_dim = Gs.input_shape[1]
        shape = (1, z_dim)
        base_latent = np.random.randn(*shape)
        if use_d:
            base_latent = Gs.components.mapping.run(base_latent, None)  # [frm,18,512]

    # every direction gets a.fstep*2 frames, offset by its index
    for i, direction in enumerate(directions):
        make_loop(base_latent, direction, lrange, a.fstep*2, a.fstep*2 * i)
def main():
    """Render an image sequence by interpolating between saved dlatents (TF).

    All settings come from the module-level ``a`` namespace. The generator is
    loaded from ``<stem of a.model>.pkl`` (and rebuilt with custom kwargs when
    ``a.model`` does not contain '.pkl'), key dlatents are read from the file
    or directory ``a.dlatents``, layers 5 and above are optionally replaced by
    ``a.style_npy_file``, and a per-frame timeline from ``latent_anima`` is
    rendered into ``a.out_dir``, one image per frame.

    Exits with status 1 when the model has no ``resolution`` kwarg or when no
    input dlatents are available.
    """
    os.makedirs(a.out_dir, exist_ok=True)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE: pickle.load can execute arbitrary code; only load trusted models
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    # a 3-item bundle is reduced to its last item; a single network is kept
    try:
        _, _, network = network
    except Exception:
        pass
    # the stored kwargs of the loaded network override the defaults above
    for k, v in network.static_kwargs.items():
        Gs_kwargs[k] = v

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    dz_dim = 512  # dlatent_size
    try:
        dl_dim = 2 * (int(np.floor(np.log2(Gs_kwargs.resolution))) - 1)
    except Exception:
        print(' Resave model, no resolution kwarg found!')
        exit(1)
    dlat_shape = (1, dl_dim, dz_dim)  # [1,18,512]

    # read saved latents
    key_dlatents = None
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        loaded = []
        for npy in file_list(a.dlatents, 'npy'):
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            loaded.append(key_dlatent)
        # np.concatenate raises on an empty list, so a directory without any
        # latent file is reported like a missing input instead of crashing
        if loaded:
            key_dlatents = np.concatenate(loaded)  # [frm,18,512]
    if key_dlatents is None:
        print(' No input dlatents found')
        exit(1)
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_npy_file is not None:
        print(' styling with latent', a.style_npy_file)
        style_dlatent = load_latents(a.style_npy_file)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than dl_dim
        style_layers = range(5, dl_dim)
        key_dlatents[:, :, style_layers, :] = style_dlatent[:, :, style_layers, :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]

    # truncation trick, applied to the first 8 layers only
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    tr_range = range(0, 8)
    dlatents[:, :, tr_range, :] = dlatent_avg + (dlatents[:, :, tr_range, :] -
                                                 dlatent_avg) * a.trunc

    # distort image by tweaking initial const layer
    if a.digress > 0:
        latent_size = Gs.static_kwargs.get('latent_size', 512)  # default latent size
        init_res = Gs.static_kwargs.get('init_res', (4, 4))  # default initial layer size
        # sized by frame_count because the loop below indexes dconst[i]
        # for every rendered frame
        dconst = a.digress * latent_anima([1, latent_size, *init_res],
                                          frame_count,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.components.synthesis.run(dlatents[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)
        else:
            output = Gs.components.synthesis.run(dlatents[i], [None],
                                                 dconst[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)

        # 4-channel output is saved as png, anything else as jpg
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
Exemple #6
0
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v,
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the training options and launch a run via ``dnnlib.submit_run``.

    Args:
        dataset: Dataset path; an existing 'tfr' file in its directory whose
            path contains the dataset's basename is reused, otherwise one is
            created with ``create_from_images``.
        train_dir: Root directory for run results.
        config: Network config, 'e' or 'f'.
        d_aug: Use the DiffAugment training loop and loss instead of the
            original ones.
        diffaug_policy: Policy passed to the DiffAugment loss.
        cond: Condition on the full label.
        ops: Value assigned to ``impl`` of both networks.
        jpg_data: Forwarded to tfrecords creation and to the dataset options.
        mirror, mirror_v: Mirror augmentation flags.
        lod_step_kimg: Length (kimg) of each lod training/transition phase.
        batch_size: Per-GPU minibatch size.
        resume, resume_kimg: Snapshot to resume from and its kimg.
        finetune: Use the fixed 15000 kimg schedule.
        num_gpus: Number of GPUs.
        ema_kimg: Optional ``G_ema_kimg`` override.
        gamma: Optional loss gamma (applied for config 'e' only, default 100).
        freezeD: Set ``freezeD`` on the discriminator options.

    Exits with status 1 when ``config`` is neither 'e' nor 'f'.
    """
    # Validate first: the schedule below reads options that only exist for
    # 'e'/'f', and there is no point converting a dataset for a bad config.
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        exit(1)

    # dataset (tfrecords) - preprocess or get
    tfr_files = file_list(os.path.dirname(dataset), 'tfr')
    tfr_files = [f for f in tfr_files if basename(dataset) in f]
    if len(tfr_files) == 0:
        tfr_file, _ = create_from_images(dataset, jpg=jpg_data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)

    desc = basename(tfr_file).split('-')[0]

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # resolutions
    data_res = basename(tfr_file).split('-')[-1].split(
        'x')  # get resolution from dataset filename
    data_res = list(reversed([int(x)
                              for x in data_res]))  # convert to int list
    init_res, resolution, res_log2 = calc_init_res(data_res)
    if init_res != [4, 4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2  # a la ProGAN
    if finetune:
        train.total_kimg = 15000  # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate
    sched.G_lrate_base = 0.001
    if config == 'e' and not finetune:  # train 1024
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug:
            loss_args.gamma = 100 if gamma is None else gamma
        else:
            D_loss.gamma = 100 if gamma is None else gamma
    else:  # config == 'f'
        G.fmap_base = D.fmap_base = 16 << 10

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    # resume_with_new_nets is always passed as True here, regardless of freezeD
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Exemple #7
0
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v,
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the training options and launch a run via ``dnnlib.submit_run``.

    Args:
        data: Dataset path; an existing non-empty 'tfr' file in its directory
            whose name prefix (before the first '-') equals the dataset's
            basename is reused, otherwise a new one is created.
        train_dir: Root directory for run results.
        config: Network config, 'e' or 'f'.
        d_aug: Use the DiffAugment training loop and loss instead of the
            original ones.
        diffaug_policy: Policy passed to the DiffAugment loss.
        cond: Condition on the full label; ``cond is True`` also selects
            ``create_from_image_folders`` for dataset creation.
        ops: Value assigned to ``impl`` of both networks.
        mirror, mirror_v: Mirror augmentation flags.
        kimg: Total training length in kimg.
        batch_size: Per-GPU minibatch size; None means ``4096 // resolution``.
        lrate: Base learning rate used for config 'f'.
        resume, resume_kimg: Snapshot to resume from and its kimg.
        num_gpus: Number of GPUs.
        ema_kimg: Optional ``G_ema_kimg`` override.
        gamma: Optional loss gamma (applied for config 'e' only, default 100).
        freezeD: Set ``freezeD`` on the discriminator options.

    Exits with status 1 when ``config`` is neither 'e' nor 'f'.
    """
    # Validate first: the schedule below reads options that only exist for
    # 'e'/'f', and there is no point converting a dataset for a bad config.
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        exit(1)

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    tfr_files = file_list(os.path.dirname(data), 'tfr')
    tfr_files = [
        f for f in tfr_files if basename(data) == basename(f).split('-')[0]
    ]
    if len(tfr_files) == 0 or os.stat(tfr_files[0]).st_size == 0:
        if cond is True:
            tfr_file, _ = create_from_image_folders(data)
        else:
            tfr_file, _ = create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(
            **dataset_args)  # loading the data to see what comes out
        resolution = dataset_obj.resolution
        init_res = dataset_obj.init_res
        res_log2 = dataset_obj.res_log2
        dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (basename(data), resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    train.savenames = [desc.replace(basename(data), 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    else:  # config == 'f'
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug:
            loss_args.gamma = 100 if gamma is None else gamma
        else:
            D_loss.gamma = 100 if gamma is None else gamma
    else:  # config == 'f'
        G.fmap_base = D.fmap_base = 16 << 10

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    # resume_with_new_nets is always passed as True here, regardless of freezeD
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def main():
    """Render latent-direction loops with a PyTorch StyleGAN2 generator.

    All settings come from the module-level ``a`` namespace. The ``G_ema``
    generator is loaded from ``<stem of a.model>.pkl``, direction vectors are
    read from every 'npy' file in ``a.vector_dir``, and ``make_loop`` is
    called once per direction, starting from ``a.base_lat`` or a random
    latent.

    Sets the module globals ``Gs``, ``custom`` and (when the directions are
    multi-layer) ``use_d``. Exits with status 1 when no direction vectors are
    found.
    """
    if a.vector_dir is not None:
        # drop a single trailing path separator
        if a.vector_dir.endswith(('/', '\\')):
            a.vector_dir = a.vector_dir[:-1]
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    global Gs, use_d, custom

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network: a model path without '.pkl' in it selects
    # "custom" loading; in both cases the file opened is '<stem>.pkl'
    pkl_name = osp.splitext(a.model)[0]
    custom = '.pkl' not in a.model.lower()
    if custom:
        print(' .. Gs custom ..', basename(a.model))
    else:
        print(' .. Gs from pkl ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(
            device)  # type: ignore

    # load directions
    vector_list = file_list(a.vector_dir, 'npy') if a.vector_dir is not None else []
    if not vector_list:
        # also covers an empty directory, where np.concatenate would raise
        print(' No vectors found')
        exit(1)
    directions = []
    for v in vector_list:
        direction = load_latents(v)
        if len(direction.shape) == 2:
            direction = np.expand_dims(direction, 0)
        directions.append(direction)
    directions = np.concatenate(directions)[:, np.newaxis]  # [frm,1,18,512]

    # the shape of the last loaded vector decides the mode: more than one row
    # per entry switches to 'd' mode; otherwise the global use_d is untouched
    if len(direction[0].shape) > 1 and direction[0].shape[0] > 1:
        use_d = True
    print(' directions', directions.shape, 'using d' if use_d else 'using w')
    directions = torch.from_numpy(directions).to(device)

    # latent direction range
    lrange = [-0.5, 0.5]

    # load saved latents
    if a.base_lat is not None:
        base_latent = load_latents(a.base_lat)
        base_latent = torch.from_numpy(base_latent).to(device)
    else:
        print(' No NPY input given, making random')
        # the random latent must be a tensor on the same device as the
        # generator and the directions, not a bare numpy array
        base_latent = torch.from_numpy(np.random.randn(1, Gs.z_dim)).to(
            device=device, dtype=torch.float32)
        if use_d:
            with torch.no_grad():
                base_latent = Gs.mapping(base_latent, None)  # [frm,18,512]

    # every direction gets a.fstep * 2 frames, offset by its index
    pbar = ProgressBar(len(directions))
    for i, direction in enumerate(directions):
        make_loop(base_latent, direction, lrange, a.fstep * 2, a.fstep * 2 * i)
        pbar.upd()
Exemple #9
0
def _load_key_dlatents(path):
    """Load key dlatents from one latents file or from every .npy file in a directory.

    Each loaded array with exactly two dimensions gets a leading axis added,
    then all arrays are concatenated along axis 0.

    Returns None when `path` is None, is neither a file nor a directory, or is
    a directory in which `file_list` finds no 'npy' entries.
    """
    if path is None:
        return None
    if osp.isfile(path):
        sources = [path]
    elif osp.isdir(path):
        sources = file_list(path, 'npy')
    else:
        return None
    if not sources:
        # np.concatenate([]) would raise an opaque ValueError; let the caller report it
        return None

    loaded = []
    for src in sources:
        latent = load_latents(src)
        if len(latent.shape) == 2:
            latent = np.expand_dims(latent, 0)
        loaded.append(latent)
    return np.concatenate(loaded)  # [frm,18,512]


def main():
    """Render an image sequence interpolated between saved key dlatents.

    Reads its settings from the module-level namespace `a` (model, out_dir,
    size, scale_type, latent_size, dlatent_size, dlatents, style_npy_file,
    fstep, cubic, trunc).

    The model file name is split on '-': field 1 is parsed as the integer
    resolution, field 2 as the config letter ('e' or 'f'), and an optional
    field 3 containing 'x' as the initial resolution.

    Frames are written to `a.out_dir` as zero-padded, numbered image files.

    Raises:
        SystemExit: with status 1 when the config letter is unsupported or no
            input dlatents can be loaded.
    """
    os.makedirs(a.out_dir, exist_ok=True)

    # parse filename to model parameters
    mparams = basename(a.model).split('-')
    res = int(mparams[1])
    cfg = mparams[2].lower()

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_custom.G_main'
    Gs_kwargs.verbose = False
    Gs_kwargs.resolution = res
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.latent_size = a.latent_size

    if cfg == 'f':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
    elif cfg == 'e':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
        Gs_kwargs.fmap_base = 8 << 10
    else:
        print(' old modes [A-D] not implemented')
        raise SystemExit(1)

    # check initial model resolution
    if len(mparams) > 3 and 'x' in mparams[3].lower():
        init_res = [int(x) for x in mparams[3].lower().split('x')]
        Gs_kwargs.init_res = init_res[::-1]  # [H,W]

    # load model, check channels
    tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE(security): pickle.load can execute arbitrary code; only open trusted model files.
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        # a 3-item payload is unpacked and its last item used as the generator
        _, _, src_net = network
    except (TypeError, ValueError):
        # not iterable, or not exactly three items: treat the payload as the generator
        src_net = network
    Gs_kwargs.num_channels = src_net.output_shape[1]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..')
        Gs = src_net
    else:
        print(' .. Gs custom ..')
        Gs = tflib.Network('Gs', **Gs_kwargs)
        # copy from the unpacked generator, not from the raw (possibly tuple) payload
        Gs.copy_vars_from(src_net)

    dl_dim = 2 * (int(np.floor(np.log2(res))) - 1)
    dlat_shape = (1, dl_dim, a.dlatent_size)  # [1,18,512]

    # read saved latents
    key_dlatents = _load_key_dlatents(a.dlatents)
    if key_dlatents is None:
        print(' No input dlatents found')
        raise SystemExit(1)
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_npy_file is not None:
        print(' styling with latent', a.style_npy_file)
        style_dlatent = load_latents(a.style_npy_file)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try other values < dl_dim besides 5
        key_dlatents[:, :, 5:dl_dim, :] = style_dlatent[:, :, 5:dl_dim, :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # presumably [frm,1,18,512], given the 4-D indexing below
    print(' dlatents', dlatents.shape)

    # truncation trick, applied to the first 8 layers only
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    trunc_layers = slice(0, 8)  # a slice clips safely if there are fewer than 8 layers
    dlatents[:, :, trunc_layers, :] = dlatent_avg + (
        dlatents[:, :, trunc_layers, :] - dlatent_avg) * a.trunc

    # loop for graph frame by frame
    pbar = ProgressBar(dlatents.shape[0])
    for i, dlatent in enumerate(dlatents):
        output = Gs.components.synthesis.run(dlatent,
                                             randomize_noise=False,
                                             output_transform=fmt,
                                             minibatch_size=1)

        # four values on the last axis -> png, otherwise jpg
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%05d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
def main():
    """Generate latent-direction loops from a directory of saved direction vectors.

    Reads its settings from the module-level namespace `a` (vector_dir,
    out_dir, verbose, size, scale_type, ops, model, npy_file, fstep) and sets
    the module globals `Gs` (the loaded generator) and `use_d`.

    For every direction loaded from `a.vector_dir`, `make_loop` is called with
    the base latent, the direction, the fixed range [-0.5, 0.5], a frame count
    of `a.fstep * 2`, and a start index of `a.fstep * 2 * i`.

    Raises:
        SystemExit: with status 1 when no vector directory is given or it
            contains no .npy files.
    """
    global Gs, use_d

    # drop one trailing path separator from the vector directory
    if a.vector_dir is not None and a.vector_dir.endswith(('/', '\\')):
        a.vector_dir = a.vector_dir[:-1]
    os.makedirs(osp.join(a.out_dir, 'ttt'), exist_ok=True)

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE(security): pickle.load can execute arbitrary code; only open trusted model files.
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        # a 3-item payload is unpacked and its last item kept
        _, _, network = network
    except (TypeError, ValueError):
        # not iterable, or not exactly three items: use the payload as-is
        pass
    # settings stored with the network override the ones set above
    for key, value in network.static_kwargs.items():
        Gs_kwargs[key] = value

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    # load directions
    if a.vector_dir is None:
        print(' No vectors found')
        raise SystemExit(1)
    vector_list = file_list(a.vector_dir, 'npy')
    if not vector_list:
        # np.concatenate([]) would otherwise fail with an opaque ValueError
        print(' No vectors found in', a.vector_dir)
        raise SystemExit(1)
    loaded = []
    for v in vector_list:
        vec = load_latents(v)
        if len(vec.shape) == 2:
            vec = np.expand_dims(vec, 0)
        loaded.append(vec)
    directions = np.concatenate(loaded)[:, np.newaxis]  # [frm,1,18,512]

    # Decide the latent space from the last loaded vector. Always assign the
    # global, so it is defined even when the check is false.
    probe = loaded[-1][0]
    use_d = len(probe.shape) > 1 and probe.shape[0] > 1
    print(' directions', directions.shape, 'using d' if use_d else 'using w')

    # latent direction range
    lrange = [-0.5, 0.5]

    # load saved latents
    if a.npy_file is not None:
        base_latent = load_latents(a.npy_file)
    else:
        print(' No NPY input given, making random')
        base_latent = np.random.randn(1, Gs.input_shape[1])
        if use_d:
            base_latent = Gs.components.mapping.run(base_latent,
                                                    None)  # [frm,18,512]

    for i, direction in enumerate(directions):
        make_loop(base_latent, direction, lrange, a.fstep * 2, a.fstep * 2 * i)