Example #1
0
def setup_snapshot_image_grid(training_set, drange_net, grid_size=None,
    size    = '1080p',      # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
    layout  = 'random'):    # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Choose a snapshot grid size and fill it with real images and labels.

    Args:
        training_set: dataset object exposing ``shape`` (assumed [C, H, W]
            given the indexing below — TODO confirm against caller),
            ``label_size``, ``label_dtype``, ``dynamic_range`` and
            ``get_minibatch_np``.
        drange_net: dynamic range expected by the networks, e.g. [-1, 1].
        grid_size: optional (gw, gh) override; when None the size is
            derived from ``size``.
        size: target display, '1080p' or '4k' (only used when grid_size
            is None).
        layout: 'random', or 'row_per_class' to make each grid row hold
            images of one class.

    Returns:
        ((gw, gh), reals, labels) where reals and labels are numpy arrays
        with gw*gh entries each.

    Raises:
        ValueError: if grid_size is None and ``size`` is unrecognized.
            (Previously this case fell through and crashed with a
            NameError on gw/gh.)
    """
    # Select size.
    if grid_size is not None:
        gw, gh = grid_size[0], grid_size[1]
    elif size == '1080p':
        gw = np.clip(1920 // training_set.shape[2], 3, 32)
        gh = np.clip(1080 // training_set.shape[1], 2, 32)
    elif size == '4k':
        gw = np.clip(3840 // training_set.shape[2], 7, 32)
        gh = np.clip(2160 // training_set.shape[1], 4, 32)
    else:
        raise ValueError("Unknown size: %r (expected '1080p' or '4k')" % (size,))

    # Fill in reals and labels.
    reals = np.zeros([gw * gh] + training_set.shape, dtype=np.float32)
    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
    for idx in range(gw * gh):
        x = idx % gw; y = idx // gw
        while True:
            real, label = training_set.get_minibatch_np(1)
            real = real.astype(np.float32)
            real = misc.adjust_dynamic_range(real, training_set.dynamic_range, drange_net)
            # In 'row_per_class' mode, keep resampling until the image's
            # one-hot label matches the class assigned to this grid row.
            if layout == 'row_per_class' and training_set.label_size > 0:
                if label[0, y % training_set.label_size] == 0.0:
                    continue
            reals[idx] = real[0]
            labels[idx] = label[0]
            break

    return (gw, gh), reals, labels
Example #2
0
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """Pre-process a minibatch of real images for progressive-GAN training.

    Args:
        x: image tensor; the reshapes below assume NCHW layout
            [batch, channels, height, width] — TODO confirm against caller.
        lod: scalar level-of-detail; the fractional part drives the
            crossfade and the integer part the upscaling factor.
        mirror_augment: if True, flip a random half of the batch.
        drange_data: dynamic range of the input data, e.g. [0, 255].
        drange_net: dynamic range expected by the networks, e.g. [-1, 1].

    Returns:
        The processed image tensor.
    """
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            x = tf.cast(x, tf.float32)
            x = misc.adjust_dynamic_range(x, drange_data, drange_net)
        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                s = tf.shape(x)
                # One uniform coin flip per image, broadcast over all pixels.
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                # Reverse along the last axis (width, assuming NCHW).
                x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3]))
        with tf.name_scope(
                'FadeLOD'
        ):  # Smooth crossfade between consecutive levels-of-detail.
            s = tf.shape(x)
            # 2x2 average-pool, then nearest-neighbour upsample back to the
            # original size; blend the blurred copy in by the fractional LOD.
            y = tf.reshape(x, [-1, s[1], s[2] // 2, 2, s[3] // 2, 2])
            y = tf.reduce_mean(y, axis=[3, 5], keepdims=True)
            y = tf.tile(y, [1, 1, 1, 2, 1, 2])
            y = tf.reshape(y, [-1, s[1], s[2], s[3]])
            x = tfutil.lerp(x, y, lod - tf.floor(lod))
        with tf.name_scope(
                'UpscaleLOD'
        ):  # Upscale to match the expected input/output size of the networks.
            s = tf.shape(x)
            # Integer part of lod selects a power-of-two replication factor.
            factor = tf.cast(2**tf.floor(lod), tf.int32)
            x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
            x = tf.tile(x, [1, 1, 1, factor, 1, factor])
            x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
        return x
Example #3
0
    def __init__(self, filename=None, model=None):
        """Load a pickled generator from `filename` and compile a Theano
        generation function plus inversion models.

        Args:
            filename: path to a network pickle; the third entry of the
                pickle is taken as the generator G.
            model: not supported — if given, the constructor returns
                immediately and the instance is left uninitialized
                (NOTE(review): silent early return; callers must pass
                `filename`).
        """
        if model is not None:
            # error
            return

        _, _, G = misc.load_pkl(filename)
        self.net = Net()
        self.net.G = G
        self.have_compiled = False
        # NOTE(review): the original also built a throwaway 512-dim
        # 'labels_var' TensorType here; it was dead code, unconditionally
        # overwritten below before any use, so it has been removed.

        # experiment
        num_example_latents = 10
        self.net.example_latents = train.random_latents(
            num_example_latents, self.net.G.input_shape)
        self.net.example_labels = self.net.example_latents
        # Symbolic inputs whose rank matches the example latent array.
        self.net.latents_var = T.TensorType(
            'float32',
            [False] * len(self.net.example_latents.shape))('latents_var')
        self.net.labels_var = T.TensorType(
            'float32',
            [False] * len(self.net.example_latents.shape))('labels_var')

        # Symbolic generated images, remapped from [-1, 1] to [0, 1].
        self.net.images_expr = self.net.G.eval(self.net.latents_var,
                                               self.net.labels_var,
                                               ignore_unused_inputs=True)
        self.net.images_expr = misc.adjust_dynamic_range(
            self.net.images_expr, [-1, 1], [0, 1])
        train.imgapi_compile_gen_fn(self.net)

        self.invert_models = def_invert_models(self.net,
                                               layer='conv4',
                                               alpha=0.002)
Example #4
0
def imgapi_load_net(run_id, snapshot=None, random_seed=1000, num_example_latents=1000, load_dataset=True, compile_gen_fn=True):
    """Load a pickled generator for a training run and prepare example
    latents, labels, and a symbolic Theano image expression.

    Args:
        run_id: identifier used to locate the result subdirectory.
        snapshot: optional snapshot selector passed to locate_network_pkl.
        random_seed: numpy seed used before drawing example latents.
        num_example_latents: how many example latents/labels to create.
        load_dataset: if True, also attach the dataset via imgapi_load_dataset.
        compile_gen_fn: if True, compile the generation function.

    Returns:
        A Net attribute bag carrying G, example inputs, symbolic vars, and
        the images expression (in the dataset's [0, 255] dynamic range).
    """
    class Net: pass
    net = Net()

    # Locate and unpickle the network snapshot (third entry is G).
    subdir = misc.locate_result_subdir(run_id)
    net.result_subdir = subdir
    net.network_pkl = misc.locate_network_pkl(subdir, snapshot)
    _, _, net.G = misc.load_pkl(net.network_pkl)

    # Deterministic example latents; labels are an empty (N, 0) array.
    np.random.seed(random_seed)
    net.example_latents = random_latents(num_example_latents, net.G.input_shape)
    net.example_labels = np.zeros((num_example_latents, 0), dtype=np.float32)
    net.dynamic_range = [0, 255]
    if load_dataset:
        imgapi_load_dataset(net)

    # Symbolic Theano inputs whose rank matches the example arrays.
    def _sym_input(name, ndim):
        return T.TensorType('float32', [False] * ndim)(name)
    net.latents_var = _sym_input('latents_var', len(net.example_latents.shape))
    net.labels_var = _sym_input('labels_var', len(net.example_labels.shape))

    # Pin the level-of-detail when the network exposes one.
    if hasattr(net.G, 'cur_lod'):
        net.lod = net.G.cur_lod.get_value()
        net.images_expr = net.G.eval(net.latents_var, net.labels_var,
                                     min_lod=net.lod, max_lod=net.lod,
                                     ignore_unused_inputs=True)
    else:
        net.lod = 0.0
        net.images_expr = net.G.eval(net.latents_var, net.labels_var,
                                     ignore_unused_inputs=True)

    # Remap generator output from [-1, 1] to the dataset dynamic range.
    net.images_expr = misc.adjust_dynamic_range(net.images_expr, [-1,1], net.dynamic_range)
    if compile_gen_fn:
        imgapi_compile_gen_fn(net)
    return net
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """Prepare a batch of real images: remap dynamic range, optionally
    mirror-augment, crossfade between LODs, and upscale to network size.

    `lod` is a scalar level-of-detail: its fractional part controls the
    crossfade blend, its integer part the upscaling factor. The reshapes
    assume NCHW layout — TODO confirm against caller.
    """
    with tf.name_scope('ProcessReals'):
        # Map pixel values from the dataset range into the network range.
        with tf.name_scope('DynamicRange'):
            x = misc.adjust_dynamic_range(tf.cast(x, tf.float32), drange_data, drange_net)
        if mirror_augment:
            # Flip a random half of the minibatch along the last axis.
            with tf.name_scope('MirrorAugment'):
                shape = tf.shape(x)
                coin = tf.random_uniform([shape[0], 1, 1, 1], 0.0, 1.0)
                coin = tf.tile(coin, [1, shape[1], shape[2], shape[3]])
                x = tf.where(coin < 0.5, x, tf.reverse(x, axis=[3]))
        # Smooth crossfade between consecutive levels-of-detail: blend in a
        # 2x2 box-filtered copy by the fractional part of lod.
        with tf.name_scope('FadeLOD'):
            shape = tf.shape(x)
            blurred = tf.reshape(x, [-1, shape[1], shape[2] // 2, 2, shape[3] // 2, 2])
            blurred = tf.reduce_mean(blurred, axis=[3, 5], keepdims=True)
            blurred = tf.tile(blurred, [1, 1, 1, 2, 1, 2])
            blurred = tf.reshape(blurred, [-1, shape[1], shape[2], shape[3]])
            x = tfutil.lerp(x, blurred, lod - tf.floor(lod))
        # Upscale to match the expected input/output size of the networks.
        with tf.name_scope('UpscaleLOD'):
            shape = tf.shape(x)
            factor = tf.cast(2 ** tf.floor(lod), tf.int32)
            x = tf.reshape(x, [-1, shape[1], shape[2], 1, shape[3], 1])
            x = tf.tile(x, [1, 1, 1, factor, 1, factor])
            x = tf.reshape(x, [-1, shape[1], shape[2] * factor, shape[3] * factor])
        return x
Example #6
0
    def train(self):
        """Main progressive-growing training loop (PyTorch).

        Alternates D updates (``D_repeat`` per iteration) with one G update
        until ``total_kimg`` thousand images have been shown, logging losses,
        periodically saving fake/real image grids, and growing both networks
        when the schedule enters a new phase.
        """
        while self.sche.cur_img < 1000 * self.total_kimg:
            real = self.dataloader.get_batch().cuda()
            # NOTE(review): 'propress_real' looks like a typo for
            # 'preprocess_real'; it is defined elsewhere — confirm before
            # renaming.
            real = propress_real(real, self.sche.lod, self.sche.phase)

            # Several discriminator updates per generator update.
            for repeat in range(self.D_repeat):
                loss_d = self.update_D(real, self.sche.batchsize)
                #self.Gs.update_Gs(self.G)
                self.sche.cur_img += self.sche.batchsize
            loss_g = self.update_G(self.sche.batchsize)
            self.sche.cur_img += self.sche.batchsize

            print("loss_d:%.2f loss_g:%.2f" % (loss_d, loss_g))

            self.writer.add_scalar('Loss/loss_d', loss_d, self.sche.cur_img)
            self.writer.add_scalar('Loss/loss_g', loss_g, self.sche.cur_img)

            # Save fake and real image grids at a per-phase cadence.
            if self.sche.tick % self.sche.save_kimg_tick.get(
                    self.sche.phase, 1) == 0:
                with torch.no_grad():
                    img = self.G(self.z, self.sche.lod, self.sche.phase,
                                 self.sche.TRANSITION)
                    img = upscale_phase(img, self.sche.phase,
                                        self.sche.max_resolution)
                    os.makedirs(os.path.join(self.tag_dir, "images"),
                                exist_ok=True)
                    # Remap from [-1, 1] to [0, 1] (and clamp) for saving.
                    save_image_grid(
                        adjust_dynamic_range(img, [-1, 1], [0, 1]).clamp(0, 1),
                        os.path.join(
                            self.tag_dir, "images", "phase%d_cur%d.jpg" %
                            (self.sche.phase, self.sche.last_cur_phase_kimg)))
                    save_image_grid(
                        adjust_dynamic_range(
                            upscale_phase(real[:24], self.sche.phase,
                                          self.sche.max_resolution), [-1, 1],
                            [0, 1]).clamp(0, 1),
                        os.path.join(
                            self.tag_dir, "images", "phase%d_cur%d_real.jpg" %
                            (self.sche.phase, self.sche.last_cur_phase_kimg)))

            self.sche.update()
            # On a phase change: checkpoint, then grow both networks.
            if self.sche.NEW_PHASE is True:
                self.save_model(self.sche.phase)
                self.G.Add_To_RGB_Layer(self.sche.phase)
                self.D.Add_From_RGB_Layer(self.sche.phase)
                self.renew()
Example #7
0
# Smoke-test script: load a pickled generator, compile a Theano generation
# function via train.imgapi_compile_gen_fn, and save one generated image.
import theano
import lasagne 
import dataset
import network
from theano import tensor as T
import config
import misc
import numpy as np
import scipy.ndimage
# The third entry of the pickle is the generator G.
_, _, G = misc.load_pkl("network-snapshot-009041.pkl")

# Minimal attribute bag mimicking the `net` object used elsewhere.
class Net: pass

net = Net()
net.G = G

import train

num_example_latents = 10
net.example_latents = train.random_latents(num_example_latents, net.G.input_shape)
# No real labels here; reuse the latents so the array ranks line up.
net.example_labels = net.example_latents
# Symbolic inputs whose rank matches the example latent array.
net.latents_var = T.TensorType('float32', [False] * len(net.example_latents.shape))('latents_var')
net.labels_var  = T.TensorType('float32', [False] * len(net.example_latents.shape)) ('labels_var')

print("HIYA", net.example_latents[:1].shape, net.example_labels[:1].shape)
net.images_expr = net.G.eval(net.latents_var, net.labels_var, ignore_unused_inputs=True)
# Remap generator output from [-1, 1] to [0, 1] before saving.
net.images_expr = misc.adjust_dynamic_range(net.images_expr, [-1,1], [0,1])
train.imgapi_compile_gen_fn(net)
images = net.gen_fn(net.example_latents[:1], net.example_labels[:1])
misc.save_image(images[0], "fake4c.png", drange=[0,1])
Example #8
0
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400,
              lod_transition_kimg=400,
              total_kimg=10000,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              resume_network_pkl=None,
              resume_kimg=0.0,
              resume_time=0.0):
    """Train a progressive-growing GAN with Theano/Lasagne (Python 2 code).

    Runs the full training loop: grows the networks by gradually lowering
    the level-of-detail (LOD), recompiles the training functions whenever
    the integer LOD bracket changes, ramps learning rates up/down, fades in
    discriminator noise ("gdrop") when the fake score rises, and saves
    image grids and network pickles at tick boundaries.

    `*_kimg` parameters are in thousands of real images shown. The mutable
    default arguments (dicts/lists) are never modified in this body.
    Results are written under a fresh subdirectory of config.result_dir;
    a '_training-done.txt' marker is created on completion.
    """

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    # Gs tracks an exponential moving average of G's weights for snapshots.
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)
    misc.print_network_topology_info(G.output_layers)
    misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    elif image_grid_type == 'category':
        # One column per class: row y of column x shows a real image of
        # class x, and fakes are generated with one-hot label x.
        W = training_set.labels.shape[1]
        H = W if image_grid_size is None else image_grid_size[1]
        image_grid_size = W, H
        snapshot_fake_latents = random_latents(W * H, G.input_shape)
        snapshot_fake_labels = np.zeros((W * H, W),
                                        dtype=training_set.labels.dtype)
        example_real_images = np.zeros((W * H, ) + training_set.shape[1:],
                                       dtype=training_set.dtype)
        for x in xrange(W):
            snapshot_fake_labels[x::W, x] = 1.0
            indices = np.arange(
                training_set.shape[0])[training_set.labels[:, x] != 0]
            for y in xrange(H):
                example_real_images[x + y * W] = training_set.h5_lods[0][
                    np.random.choice(indices)]
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Learning rates live in shared variables so they can be ramped without
    # recompiling the training functions.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # Sampling function uses the smoothed generator Gs.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Impossible sentinel values force the training funcs to compile on the
    # first loop iteration.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    if config.D.get('mbdisc_kernels', None):
        # Minibatch discrimination needs a data-dependent initialization pass.
        print 'Initializing minibatch discrimination...'
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
        D.eval(real_images_var, deterministic=False, init=True)
        init_layers = lasagne.layers.get_all_layers(D.output_layers)
        init_updates = [
            update for layer in init_layers
            for update in getattr(layer, 'init_updates', [])
        ]
        init_fn = theano.function(inputs=[real_images_var],
                                  outputs=None,
                                  updates=init_updates)
        init_reals = training_set.get_random_minibatch(500, lod=initial_lod)
        init_reals = misc.adjust_dynamic_range(init_reals, drange_orig,
                                               drange_net)
        init_fn(init_reals)
        del init_reals

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            # tlod counts completed train+transition cycles (fractionally).
            tlod = (cur_nimg / 1000.0) / (lod_training_kimg +
                                          lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        # Recompile only when the integer LOD bracket changes.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()
            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            # Compile training funcs.
            if not separate_funcs:
                # Single function updates G, D, and the smoothed Gs together.
                GD_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                              updates=G_updates + D_updates +
                                              Gs.updates,
                                              on_unused_input='ignore')
            else:
                D_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                             updates=D_updates,
                                             on_unused_input='ignore')
                G_train_fn = theano.function(
                    [fake_latents_var, fake_labels_var], [],
                    updates=G_updates + Gs.updates,
                    on_unused_input='ignore')

        # Invoke training funcs.
        if not separate_funcs:
            assert D_training_repeats == 1
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_train_out = GD_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        else:
            # Multiple D updates per G update.
            for idx in xrange(D_training_repeats):
                mb_reals, mb_labels = training_set.get_random_minibatch(
                    minibatch_size,
                    lod=cur_lod,
                    shrink_based_on_lod=True,
                    labels=True)
                mb_train_out = D_train_fn(
                    mb_reals, mb_labels,
                    random_latents(minibatch_size, G.input_shape),
                    random_labels(minibatch_size, training_set))
                cur_nimg += minibatch_size
                tick_train_out.append(mb_train_out)
            G_train_fn(random_latents(minibatch_size, G.input_shape),
                       random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        # (EMA of D's loss drives gdrop_strength above the gdrop_lim knee).
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Average each training-output component over the tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Empty marker file signals a completed run.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
Example #9
0
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40 / speed_factor,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400 / speed_factor,
              lod_transition_kimg=400 / speed_factor,
              total_kimg=10000 / speed_factor,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50 / speed_factor,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=1,
              network_snapshot_ticks=4,
              image_grid_type='default',
              resume_network=None,
              resume_kimg=0.0,
              resume_time=0.0):

    training_set, drange_orig = load_dataset()

    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("-" * 50)

    if resume_network:
        print("Resuming weight from:" + resume_network)
        G = Generator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.G)
        D = Discriminator(num_channels=training_set.shape[3],
                          resolution=training_set.shape[1],
                          label_size=training_set.labels.shape[1],
                          **config.D)
        G, D = load_GD_weights(G, D,
                               os.path.join(config.result_dir, resume_network),
                               True)
        print("pre-trained weights loaded!")
        print("-" * 50)
    else:
        G = Generator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.G)
        D = Discriminator(num_channels=training_set.shape[3],
                          resolution=training_set.shape[1],
                          label_size=training_set.labels.shape[1],
                          **config.D)

    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              training_set.shape[1], training_set.shape[3])

    print(G.summary())
    print(D.summary())

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    G_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)

    if config.loss['type'] == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
    elif config.loss['type'] == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]

    G.compile(G_opt, loss=G_loss)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("current nimg: %s" % cur_nimg)
    print("-" * 50)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time

    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[1], G.output_shape[2]
            print("w:%d,h:%d" % (w, h))
            image_grid_size = np.clip(int(1920 // w), 3,
                                      16).astype('int'), np.clip(
                                          1080 / h, 2, 16).astype('int')

        print("image_grid_size:", image_grid_size)

        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch_channel_last(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)

    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)

    snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                           G.input_shape)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir,
                                      'fakes%06d.png' % (cur_nimg / 1000)),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    nimg_h = 0

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                      rampdown_kimg)

        K.set_value(G.optimizer.lr,
                    np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr,
                    np.float32(lrate_coef * G_learning_rate_max))

        K.set_value(D_train.optimizer.lr,
                    np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # train D
        d_loss = None
        for idx in range(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch_channel_last(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            mb_labels_rnd = random_labels(minibatch_size, training_set)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)

            mb_fakes = G.predict_on_batch([mb_latents])

            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig,
                                                 drange_net)
            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation], [
                    np.ones((minibatch_size, 1, 1, 1)),
                    np.ones((minibatch_size, 1))
                ])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            print("%d real score: %d fake score: %d" %
                  (idx, np.mean(d_score_real), np.mean(d_score_fake)))
            print("-" * 50)
            cur_nimg += minibatch_size

        #train G

        mb_latents = random_latents(minibatch_size, G.input_shape)
        mb_labels_rnd = random_labels(minibatch_size, training_set)

        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones(
            (mb_latents.shape[0], 1, 1, 1)))

        print("%d images [D loss: %f] [G loss: %f]" %
              (cur_nimg, d_loss, g_loss))
        print("-" * 50)

        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                save_GD_weights(
                    G, D,
                    os.path.join(result_subdir,
                                 'network-snapshot-%06d' % (cur_nimg / 1000)))

    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
Пример #10
0
def process_colors(x):
    """Map color values from the [0, 1] range into the network's [-1, 1] range."""
    with tf.name_scope('ProcessColors'):
        rescaled = misc.adjust_dynamic_range(x, [0, 1], [-1, 1])
    return rescaled
Пример #11
0
def process_masks(x):
    """Cast mask tensors to float32 and rescale [0, 255] to the network range [-1, 1]."""
    with tf.name_scope('ProcessMasks'):
        return misc.adjust_dynamic_range(tf.cast(x, tf.float32), [0, 255], [-1, 1])
Пример #12
0
def train_gan(
        separate_funcs=False,       # Unused; kept for interface compatibility.
        D_training_repeats=1,       # D updates per G update.
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,          # Unused here (no smoothed Gs is built).
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides=None,   # {resolution: minibatch}; None -> {}.
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,     # Unused; kept for interface compatibility.
        gdrop_beta=0.9,             # EMA factor for the fake-score average.
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides=None,   # {resolution: kimg/tick}; None -> built-in table.
        image_snapshot_ticks=1,
        network_snapshot_ticks=4,
        image_grid_type='default',
        resume_network='000-celeba/network-snapshot-000488',
        resume_kimg=511.6,
        resume_time=0.0):
    """Train a progressive-growing GAN (Keras/TF1 style).

    Builds G/D (optionally resuming stored weights), compiles the combined
    training models, then runs the main loop: a level-of-detail schedule
    driven by ``cur_nimg``, D updates with a gradient-penalty interpolation
    term, one G update per iteration, and per-tick image/weight snapshots
    written under a fresh result subdirectory.

    Note: ``resume_kimg`` only offsets the schedule/reporting; actual weight
    restoration happens iff ``resume_network`` is set.
    """
    # Avoid shared mutable default arguments.
    if minibatch_overrides is None:
        minibatch_overrides = {}
    if tick_kimg_overrides is None:
        tick_kimg_overrides = {32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1}

    training_set, drange_orig = load_dataset()

    # Construct G/D once; the resume path only differs in loading weights.
    if resume_network:
        print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=training_set.shape[3],
                  resolution=training_set.shape[1],
                  label_size=training_set.labels.shape[1],
                  **config.G)
    D = Discriminator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.D)
    if resume_network:
        G, D = load_GD_weights(G, D,
                               os.path.join(config.result_dir, resume_network),
                               True)

    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              training_set.shape[1], training_set.shape[3])
    print(G.summary())
    print(D.summary())

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Optimizers start at lr=0; the real rate is set every iteration below.
    G_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)

    if config.loss['type'] == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
        D_loss_weight = None  # BUGFIX: was unbound -> NameError at D_train.compile.
    elif config.loss['type'] == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]
    else:
        raise ValueError('Invalid loss type', config.loss['type'])

    G.compile(G_opt, loss=G_loss)
    D.trainable = False  # Freeze D inside the combined G_train graph.
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time

    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[1], G.output_shape[2]
            print("w:%d,h:%d" % (w, h))
            # Grid sized to roughly tile a 1920x1080 canvas, clamped to
            # [3..16] x [2..16] cells.
            image_grid_size = (int(np.clip(1920 // w, 3, 16)),
                               int(np.clip(1080 // h, 2, 16)))

        print("image_grid_size:", image_grid_size)

        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch_channel_last(
            np.prod(image_grid_size), labels=True)
        # Fixed latents reused for every fake-image snapshot.
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)

    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)

    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir,
                                      'fakes%06d.png' % (cur_nimg // 1000)),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD: starts at initial_lod and decreases through
        # alternating training/transition phases as images are consumed.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Learning-rate schedule: ramp up at the start, ramp down at the end.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                      rampdown_kimg)

        K.set_value(G.optimizer.lr,
                    np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr,
                    np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(D_train.optimizer.lr,
                    np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # Train D.
        d_loss = None
        for idx in range(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch_channel_last(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            mb_labels_rnd = random_labels(minibatch_size, training_set)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)

            mb_fakes = G.predict_on_batch([mb_latents])

            # Random real/fake interpolation for the gradient-penalty input.
            # NOTE(review): interpolation mixes reals *before* the dynamic-range
            # adjustment below with fakes already in network range -- confirm
            # this ordering is intended.
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig,
                                                 drange_net)
            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation], [
                    np.ones((minibatch_size, 1, 1, 1)),
                    np.ones((minibatch_size, 1))
                ])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            # BUGFIX: %d truncated the float scores to integers.
            print("real score: %f fake score: %f" %
                  (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # Train G.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        mb_labels_rnd = random_labels(minibatch_size, training_set)

        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones(
            (mb_latents.shape[0], 1, 1, 1)))

        print("%d [D loss: %f] [G loss: %f]" % (cur_nimg, d_loss, g_loss))

        # Gradient-drop strength: an EMA of the clipped D loss drives extra
        # noise in D when training gets unstable (only if D exposes the knob).
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Per-tick maintenance: timing stats plus image/weight snapshots.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # tick_train_out is never appended to, so this yields an empty
            # tuple; kept for compatibility with the original reporting code.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                save_GD_weights(
                    G, D,
                    os.path.join(result_subdir,
                                 'network-snapshot-%06d' % (cur_nimg // 1000)))

    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
Пример #13
0
def pre_process(
    imgs,  # Images to pre-process
    coords,  # Coords corresponding
    bboxes,  # Bounding boxes
    #threshold,              #
    drange_imgs,  # Dynamic range for the images (Typically: [0, 255])
    drange_coords,  # Dynamic range for the coordinates (Typically: [0, image.shape[0]])
    drange_net,  # Dynamic range for the network (Typically: [-1, 1])
    mirror_augment=False,  # Should mirror augment be applied?
    random_dw_conv=False,  # Apply a random depthwise convolution to this input image?
    horizontal_flip=False,  # Flip images (and coords) left-right, decided once at graph build
):
    """Build the TF pre-processing graph for detector inputs.

    Casts images and coordinates to float32, rescales both into the
    network's dynamic range, optionally applies augmentations, and maps
    ground-truth bounding boxes onto per-anchor targets plus refinement
    offsets via ``bboxes_fn`` / ``bbox_refinement``.

    Returns:
        (imgs, coords, bboxes_transformed, refinement) tensors.
    """
    with tf.name_scope('ProcessReals'):
        imgs = tf.cast(imgs, tf.float32)
        coords = tf.cast(coords, tf.float32)

        with tf.name_scope('DynamicRange'):
            imgs = misc.adjust_dynamic_range(imgs, drange_imgs, drange_net)
            coords = misc.adjust_dynamic_range(coords, drange_coords,
                                               drange_net)

        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                # Per-sample coin flip, broadcast over the whole image.
                # NOTE(review): reverses axis 3 only for imgs; coords are not
                # mirrored here -- confirm that is intentional.
                s = tf.shape(imgs)
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                imgs = tf.where(mask < 0.5, imgs, tf.reverse(imgs, axis=[3]))

        if random_dw_conv:
            with tf.name_scope('RandomDWConv'):
                # Parameters of the augmentation:
                m = 0.5
                # Random 3x3 depthwise filter in [-m, m); adding the identity
                # and dividing by (1 + m) keeps the output magnitude stable.
                filt = 2 * m * tf.random_uniform((3, 3, 3, 1)) - m
                imgs = (tf.nn.depthwise_conv2d(imgs,
                                               filt,
                                               strides=[1, 1, 1, 1],
                                               padding='SAME',
                                               data_format='NCHW') +
                        imgs) / (1 + m)

        if horizontal_flip:
            with tf.name_scope('HorizontalFlip'):
                # NOTE(review): np.random.random() executes once at graph-build
                # time, so the same flip decision applies to every minibatch.
                rand = np.random.random()
                if rand > 0.5:
                    imgs = tf.image.flip_left_right(imgs)
                    # Negate alternating coordinate components; presumably the
                    # x values are centered so negation mirrors them -- verify.
                    mult = tf.constant([-1, 1, -1, 1, -1, 1])
                    coords *= mult

    # net_bboxes = tf.constant(generate_output(64, np.arange(4,17,2)), tf.float32) # TODO: architecture independance
    with tf.name_scope('ProcessBboxes'):
        # NOTE(review): net_bboxes is not defined in this function; it must be
        # a module-level constant (see the commented line above) -- confirm it
        # exists at import time, otherwise this raises NameError.
        bboxes = tf.cast(bboxes, tf.float32)
        bboxes_transformed = tf.map_fn(
            lambda bbox: bboxes_fn(bbox, net_bboxes, config.threshold), bboxes)
        refinement = tf.map_fn(lambda bbox: bbox_refinement(bbox, net_bboxes),
                               bboxes)  # TODO: optimize refinement
        # Should refinement be normalized?

        # net_bboxes.shape = [561, 4]
        # refinement.shape = [?, 561, 4]
        # bboxes = tf.map_fn(lambda bbox: IoU(bbox, net_bboxes_area, config.threshold), bboxes)
    # Remove the irrelevent boxes
    # mask = tf.greater(tf.math.reduce_sum(bboxes,axis=-1),0)
    # bboxes = tf.boolean_mask(bboxes, mask)
    # imgs =  tf.boolean_mask(imgs,mask)
    # coords =  tf.boolean_mask(coords,mask)

    return imgs, coords, bboxes_transformed, refinement
Пример #14
0
def train_detector(
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=1,  # Total length of the training, measured in thousands of real images.
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    snapshot_size=16,  # Size of the snapshot image
    snapshot_ticks=2**13,  # Number of images before maintenance
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=1,  # How often to export network snapshots?
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train the detector network N on the configured dataset.

    Builds the training and testing graphs around ``pre_process``, runs the
    loop until ``total_kimg * 1000`` images are consumed, and periodically
    reports train/test losses, writes tfevents summaries, and saves image
    snapshots and network pickles into a fresh result subdirectory.
    """

    maintenance_start_time = time.time()

    # Load the datasets
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)

    # TODO: data augmentation
    # TODO: testing set

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...'
                  )  # TODO: better network (like lod-wise network)
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()

    print('Building TensorFlow graph...')
    # Training set up
    with tf.name_scope('Inputs'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # minibatch_in            = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        reals, labels, bboxes = training_set.get_minibatch_tf(
        )  # TODO: increase the size of the batch by several loss computation and mean
    N_opt = tfutil.Optimizer(name='TrainN',
                             learning_rate=lrate_in,
                             **config.N_opt)

    with tf.device('/gpu:0'):
        # Rescale reals into network range and build per-anchor targets.
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)

        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()

    # Testing set up
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = testing_set.get_minibatch_tf(
        )
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()

    # Save a reference grid of real images with their ground-truth boxes,
    # then an initial prediction grid from the (possibly untrained) network.
    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals,
                         test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)

    test_reals = misc.adjust_dynamic_range(test_reals,
                                           training_set.dynamic_range,
                                           drange_net)
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)

    print('Training...')
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())

    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time

    # Choose training parameters and configure training ops.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)

    _train_loss = 0

    while cur_nimg < total_kimg * 1000:

        # Run training ops.
        # for _ in range(minibatch_repeats):
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        if (cur_nimg >= total_kimg * 1000) or (cur_nimg % snapshot_ticks == 0
                                               and cur_nimg > 0):

            cur_tick += 1
            cur_time = time.time()
            # tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            # NOTE(review): accumulated loss is divided by the number of
            # *images* seen this tick, not minibatches -- confirm intended.
            _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Evaluate on the testing set once per tick.
            testing_set.configure(sched.minibatch)
            _test_loss = 0
            # testing_set_len = 1 # TMP
            # NOTE(review): sums one loss value per minibatch but divides by
            # the number of images -- verify this normalization is intended.
            for _ in range(0, testing_set_len, sched.minibatch):
                _test_loss += tfutil.run(test_loss)
            _test_loss /= testing_set_len

            # Report progress. # TODO: improved report display
            print(
                'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/maintenance', maintenance_time),
                   tfutil.autosummary('TrainN/train_loss', _train_loss),
                   tfutil.autosummary('TrainN/test_loss', _test_loss)))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Periodic snapshots: predicted boxes on the fixed test batch,
            # and a pickle of the network itself.
            if cur_tick % image_snapshot_ticks == 0:
                test_bboxes, test_refs = N.run(test_reals,
                                               minibatch_size=snapshot_size)
                misc.save_img_bboxes_ref(
                    test_reals, test_bboxes, test_refs,
                    os.path.join(result_subdir,
                                 'fakes%06d.png' % (cur_nimg // 1000)),
                    snapshot_size)
            if cur_tick % network_snapshot_ticks == 0:
                misc.save_pkl(
                    N,
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            _train_loss = 0

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # misc.save_pkl(N, os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Пример #15
0
def train_gan(
    separate_funcs=False,
    D_training_repeats=1,          # D updates per G update
    G_learning_rate_max=0.0010,
    D_learning_rate_max=0.0010,
    G_smoothing=0.999,
    adam_beta1=0.0,
    adam_beta2=0.99,
    adam_epsilon=1e-8,
    minibatch_default=16,
    minibatch_overrides={},        # NOTE(review): mutable default argument — safe only while never mutated
    rampup_kimg=40/speed_factor,
    rampdown_kimg=0,
    lod_initial_resolution=4,
    lod_training_kimg=400/speed_factor,
    lod_transition_kimg=400/speed_factor,
    total_kimg=10000/speed_factor,
    dequantize_reals=False,
    gdrop_beta=0.9,
    gdrop_lim=0.5,
    gdrop_coef=0.2,
    gdrop_exp=2.0,
    drange_net=[-1, 1],            # NOTE(review): mutable (list) defaults — safe only while never mutated
    drange_viz=[-1, 1],
    image_grid_size=None,
    tick_kimg_default=50/speed_factor,
    tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
    image_snapshot_ticks=1,
    network_snapshot_ticks=4,
    image_grid_type='default',
    # resume_network          = '000-celeba/network-snapshot-000488',
    resume_network=None,
    resume_kimg=0.0,
        resume_time=0.0):
    """Train a progressively-growing GAN (Keras/TF backend variant).

    Builds G/D (optionally resuming from `resume_network`), then runs a
    tick-based training loop: the current level of detail (LOD) is derived
    from the number of images shown so far, learning rates are ramped
    up/down, D is trained `D_training_repeats` times per G update with a
    gradient-penalty-style interpolation input, and image/network snapshots
    are written once per tick into a fresh result subdirectory.

    All `*_kimg` quantities are thousands of images (divided by the
    module-level `speed_factor`). Relies on module-level helpers:
    `load_dataset`, `build_trainable_model`, `random_latents`, `rampup`,
    `rampdown_linear`, `save_GD_weights`, `save_GD`, `misc`, `config`, `K`.
    """

    training_set, drange_orig = load_dataset()

    # Build (or resume) the generator/discriminator plus their trainable wrappers.
    G, G_train,\
        D, D_train = build_trainable_model(resume_network,
                                           training_set.shape[3],
                                           training_set.shape[1],
                                           training_set.labels.shape[1],
                                           adam_beta1, adam_beta2,
                                           adam_epsilon)

    # Misc init.
    # Full output resolution as a power of two; initial_lod counts how many
    # halvings separate it from lod_initial_resolution.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 -
                      int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Deliberately impossible sentinel values so the first comparison differs.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    cur_nimg = int(resume_kimg * 1000)  # images shown so far
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()

    train_start_time = tick_start_time - resume_time

    if image_grid_type == 'default':
        if image_grid_size is None:
            # NOTE(review): w/h taken from output_shape[1]/[2] — axis order
            # depends on channel layout; confirm against G.output_shape.
            w, h = G.output_shape[1], G.output_shape[2]
            print('w:%d,h:%d' % (w, h))
            # NOTE(review): inconsistent — first term uses int(1920 // w),
            # second uses float 1080 / h before .astype('int').
            image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                               np.clip(1080 / h, 2, 16).astype('int'))

        print('image_grid_size:', image_grid_size)

        # One real example per grid cell, channels-last.
        example_real_images, snapshot_fake_labels = \
            training_set.get_random_minibatch_channel_last(
                np.prod(image_grid_size), labels=True)

        snapshot_fake_latents = random_latents(
            np.prod(image_grid_size), G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(
        config.result_dir, config.run_desc)
    result_subdir = Path(result_subdir)

    print('example_real_images.shape:', example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         str(result_subdir / 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)

    # NOTE(review): snapshot_fake_latents is regenerated here, discarding the
    # latents drawn above — the earlier assignment is redundant.
    snapshot_fake_latents = random_latents(
        np.prod(image_grid_size), G.input_shape)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         str(result_subdir/f'fakes{cur_nimg//1000:06}.png'),
                         drange=drange_viz, grid_size=image_grid_size)

    # nimg_h = 0

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # LOD decreases (resolution grows) as more images are shown; within
        # each phase, lod_training_kimg of training is followed by
        # lod_transition_kimg of linear fade to the next resolution.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0/speed_factor)) / \
                (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) /
                               lod_transition_kimg,
                               0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(
            cur_res, tick_kimg_default)

        # Update network config.
        # Learning-rate schedule: ramp up at the start, linear ramp down near
        # total_kimg.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg /
                                      1000.0, total_kimg, rampdown_kimg)

        K.set_value(G.optimizer.lr, np.float32(
            lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr, np.float32(
            lrate_coef * G_learning_rate_max))

        K.set_value(D_train.optimizer.lr, np.float32(
            lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(
            np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # train D
        d_loss = None
        for idx in range(D_training_repeats):
            mb_reals, mb_labels = \
                training_set.get_random_minibatch_channel_last(
                    minibatch_size, lod=cur_lod, shrink_based_on_lod=True,
                    labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)

            # compensate for shrink_based_on_lod
            # Nearest-neighbour upsample shrunken reals back to full size.
            if min_lod > 0:
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)

            mb_fakes = G.predict_on_batch([mb_latents])

            # Random interpolation between reals and fakes — the gradient
            # penalty input of WGAN-GP style training.
            # NOTE(review): interpolation mixes reals in the ORIGINAL dynamic
            # range with fakes in the network range; the range adjustment of
            # mb_reals happens only afterwards — confirm this is intended.
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon*mb_reals + (1-epsilon)*mb_fakes
            mb_reals = misc.adjust_dynamic_range(
                mb_reals, drange_orig, drange_net)
            d_loss, d_diff, d_norm = \
                D_train.train_on_batch([mb_fakes, mb_reals, interpolation],
                                       [np.ones((minibatch_size, 1, 1, 1)),
                                        np.ones((minibatch_size, 1))])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            # NOTE(review): %d truncates the float mean scores to integers —
            # %f was probably intended.
            print('real score: %d fake score: %d' %
                  (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # train G
        mb_latents = random_latents(minibatch_size, G.input_shape)

        # Generator target is -1 (critic output maximization in this setup).
        g_loss = G_train.train_on_batch(
            [mb_latents], (-1)*np.ones((mb_latents.shape[0], 1, 1, 1)))

        print('%d [D loss: %f] [G loss: %f]' % (cur_nimg, d_loss, g_loss))

        # Exponential moving average of the (clipped) fake score drives the
        # gdrop regularization strength.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + \
            fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * \
            (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        is_complete = cur_nimg >= total_kimg * 1000
        is_generate_a_lot = \
            cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000

        if is_generate_a_lot or is_complete:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time

            print(f'tick time: {tick_time}')
            print(f'tick image: {tick_kimg}k')
            tick_start_time = cur_time

            # Visualize generated images.
            is_image_snapshot_ticks = cur_tick % image_snapshot_ticks == 0
            if is_image_snapshot_ticks or is_complete:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     str(result_subdir /
                                         f'fakes{cur_nimg // 1000:06}.png'),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            if cur_tick % network_snapshot_ticks == 0 or is_complete:
                save_GD_weights(G, D, str(
                    result_subdir / f'network-snapshot-{cur_nimg // 1000:06}'))
        # NOTE(review): unconditional break — the while loop runs exactly one
        # iteration, so training stops after a single pass. Looks like a
        # debugging leftover; remove to train for the full total_kimg.
        break
    save_GD(G, D, str(result_subdir/'network-final'))
    training_set.close()
    print('Done.')

    train_complete_time = time.time()-train_start_time
    print(f'training time: {train_complete_time}')
Пример #16
0
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,          # D updates per G update
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,             # EMA decay for the smoothed generator Gs
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},        # NOTE(review): mutable default argument — safe only while never mutated
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        #lod_training_kimg       = 40,
        #lod_transition_kimg     = 40,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        #tick_kimg_default       = 1,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=4,
        network_snapshot_ticks=40,
        image_grid_type='default',
        #resume_network_pkl      = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
        resume_network_pkl=None,
        resume_kimg=0,
        resume_time=0.0):
    """Train a progressively-growing GAN (Theano/Lasagne, Python 2 variant).

    Loads the dataset, builds or resumes G/D plus a temporally-smoothed
    generator Gs, compiles Theano training functions (recompiled whenever the
    integer LOD bracket changes), and runs a tick-based loop that trains D
    `D_training_repeats` times per G update, fades in gdrop noise when the
    discriminator looks unstable, and periodically saves image grids and
    network pickles.

    All `*_kimg` quantities are thousands of images (scaled by the
    module-level `speed_factor`).
    """

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()

    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # training_set is the object the dataset module builds after parsing the h5 file;
    # drange_orig is training_set.get_dynamic_range()

    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    # Gs tracks an exponential moving average of G's weights for visualization.
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)

    # G and D can either be loaded from a pkl via misc, or built by the network module.

    print G
    print D

    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configure the layout of the intermediate snapshot images.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var),real_images_var
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # All names ending in _var are symbolic input tensors.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # theano.shared wraps an initial value into a shared variable that can be
    # updated in place later (set_value) without recompiling.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # gen_fn is a compiled function with inputs [fake_latents_var, fake_labels_var]
    # and output Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True).

    # Image generation function.

    # Misc init.
    # Full output resolution as a power of two.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # lod = level of detail.
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Deliberately impossible sentinel values so the first LOD comparison
    # always triggers compilation of the training functions.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point.
    # Note that training never leaves the outermost while loop, so operations
    # such as switching resolution necessarily happen inside it.

    # Number of images seen so far.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0

    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # Compute the current level of detail from the image count.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        # Current resolution.
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        # Update learning rates and LOD shared variables in place.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))

        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))

        #print " cur_lod%f\n  min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n"%(cur_lod,min_lod,new_min_lod,max_lod,new_max_lod)

        # Recompile the training functions whenever the integer LOD bracket
        # changes (expensive, so only done on transitions).
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_train_fn = theano.function([
                real_images_var, real_labels_var, fake_latents_var,
                fake_labels_var
            ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                         updates=D_updates,
                                         on_unused_input='ignore')
            # G_train_fn also applies Gs.updates so the smoothed generator
            # tracks G on every G step.
            G_train_fn = theano.function([fake_latents_var, fake_labels_var],
                                         [],
                                         updates=G_updates + Gs.updates,
                                         on_unused_input='ignore')

        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)

            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels

            mb_train_out = D_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Average each training output (loss/score) over the whole tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Marker file signalling that the run finished cleanly.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
Пример #17
0
def classify(model_path, testing_data_path):
    """Classify an image (or every image in a directory) with a GAN-attribution
    network.

    Loads the classifier pickled at `model_path`, preprocesses each input to a
    128x128 RGB CHW tensor in [-1, 1], and prints which source distribution
    the classifier attributes the image(s) to. The label set is chosen by the
    width of the classifier's logits: 5 outputs -> labels_1 (different GAN
    architectures), 11 outputs -> labels_2 (ProGAN seeds).

    Args:
        model_path: path to the pickled classifier network (misc.load_network_pkl).
        testing_data_path: a .png/.jpg file, or a directory of images.
    """

    labels_1 = [
        'CelebA_real_data', 'ProGAN_generated_data', 'SNGAN_generated_data',
        'CramerGAN_generated_data', 'MMDGAN_generated_data'
    ]
    labels_2 = [
        'CelebA_real_data', 'ProGAN_seed_0_generated_data ',
        'ProGAN_seed_1_generated_data', 'ProGAN_seed_2_generated_data',
        'ProGAN_seed_3_generated_data', 'ProGAN_seed_4_generated_data',
        'ProGAN_seed_5_generated_data', 'ProGAN_seed_6_generated_data',
        'ProGAN_seed_7_generated_data', 'ProGAN_seed_8_generated_data',
        'ProGAN_seed_9_generated_data'
    ]

    print('Loading network...')
    C_im = misc.load_network_pkl(model_path)

    if testing_data_path.endswith('.png') or testing_data_path.endswith(
            '.jpg'):
        # Single-image mode: normalize to float32 [0, 1].
        im = np.array(PIL.Image.open(testing_data_path)).astype(
            np.float32) / 255.0
        if len(im.shape) < 3:
            # Grayscale -> replicate to 3 channels.
            im = np.dstack([im, im, im])
        if im.shape[2] == 4:
            # Drop alpha channel.
            im = im[:, :, :3]
        if im.shape[0] != 128:
            im = skimage.transform.resize(im, (128, 128))
        # HWC [0, 1] -> CHW [-1, 1], then add a batch dimension.
        im = np.transpose(misc.adjust_dynamic_range(im, [0, 1], [-1, 1]),
                          axes=[2, 0, 1])
        im = np.reshape(im, [1] + list(im.shape))
        logits = C_im.run(im,
                          minibatch_size=1,
                          num_gpus=1,
                          out_dtype=np.float32)
        idx = np.argmax(np.squeeze(logits))
        if logits.shape[1] == len(labels_1):
            labels = list(labels_1)
        elif logits.shape[1] == len(labels_2):
            labels = list(labels_2)
        # NOTE(review): if the logits width matches neither label list,
        # `labels` is unbound and the next line raises UnboundLocalError.
        print('The input image is predicted as being sampled from %s' %
              labels[idx])

    elif os.path.isdir(testing_data_path):
        # Directory mode: classify every file and tally predictions per label.
        count_dict = None
        name_list = sorted(os.listdir(testing_data_path))
        length = len(name_list)
        for (count0, name) in enumerate(name_list):
            # Same preprocessing pipeline as the single-image branch above.
            im = np.array(PIL.Image.open('%s/%s' %
                                         (testing_data_path, name))).astype(
                                             np.float32) / 255.0
            if len(im.shape) < 3:
                im = np.dstack([im, im, im])
            if im.shape[2] == 4:
                im = im[:, :, :3]
            if im.shape[0] != 128:
                im = skimage.transform.resize(im, (128, 128))
            im = np.transpose(misc.adjust_dynamic_range(im, [0, 1], [-1, 1]),
                              axes=[2, 0, 1])
            im = np.reshape(im, [1] + list(im.shape))
            logits = C_im.run(im,
                              minibatch_size=1,
                              num_gpus=1,
                              out_dtype=np.float32)
            idx = np.argmax(np.squeeze(logits))
            if logits.shape[1] == len(labels_1):
                labels = list(labels_1)
            elif logits.shape[1] == len(labels_2):
                labels = list(labels_2)
            if count_dict is None:
                # Lazily initialize tallies once the label set is known.
                count_dict = {}
                for label in labels:
                    count_dict[label] = 0
            count_dict[labels[idx]] += 1
            print(
                'Classifying %d/%d images: %s: predicted as being sampled from %s'
                % (count0, length, name, labels[idx]))
        # NOTE(review): on an empty directory, `labels` is unbound here and
        # `length` is 0 (division by zero below) — confirm inputs are
        # guaranteed non-empty.
        for label in labels:
            print(
                'The percentage of images sampled from %s is %d/%d = %.2f%%' %
                (label, count_dict[label], length,
                 float(count_dict[label]) / float(length) * 100.0))