Example #1
import numpy as np
from dnnlib import EasyDict

def training_schedule(cur_nimg,                          # number of real images shown to the networks so far
                      resolution_log2,                   # log2 of the final (full) training resolution
                      num_gpus,                          # number of GPUs participating in training
                      lod_initial_resolution=4,          # image resolution used at the start of training
                      lod_training_kimg=600,             # thousands of images to train at each resolution before doubling
                      lod_transition_kimg=600,           # thousands of images spent fading in each new resolution
                      minibatch_base=4,                  # default total minibatch size
                      minibatch_dict={},                 # resolution-specific overrides of the minibatch size
                      max_minibatch_per_gpu={},          # resolution-specific caps on the per-GPU minibatch size
                      G_lrate_base=0.001,                # generator learning rate
                      G_lrate_dict={},                   # resolution-specific generator learning rates
                      D_lrate_base=0.001,                # discriminator learning rate
                      D_lrate_dict={},                   # resolution-specific discriminator learning rates
                      lrate_rampup_kimg=0,               # duration of the learning-rate ramp-up, in kimg
                      tick_kimg_base=60,                 # note: only 1/10 of the official implementation's values; my GPU is too slow...
                      tick_kimg_dict={4: 60, 8: 40, 16: 20, 32: 20, 64: 20, 128: 20, 256: 20, 512: 20, 1024: 20}
                      ):
    # dnnlib comes with EasyDict
    s = EasyDict()
    s.kimg = cur_nimg / 1000.0  # training progress in thousands of real images

    # Determine the current growth phase (train at one resolution, then fade in the next).
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    s.lod = resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    s.resolution = 2 ** (resolution_log2 - int(np.floor(s.lod)))
    s.resolution_log2 = int(np.log2(s.resolution))
    s.alpha = 1 - (s.lod - (resolution_log2 - s.resolution_log2))  # fade-in weight of the newest layers; ramps 0 -> 1 during each transition
    assert 0 <= s.alpha <= 1.0

    # Minibatch size.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus  # round down to a multiple of num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
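
A minimal usage sketch of the schedule above, assuming the function and its imports are in scope; the 256x256 target (resolution_log2=8), num_gpus=4, and the per-resolution minibatch sizes are illustrative assumptions, not values taken from this repo:

# Illustrative only: resolution_log2=8 (256x256) and minibatch_dict are assumed values.
for cur_nimg in [0, 600_000, 900_000, 1_200_000, 2_400_000]:
    s = training_schedule(cur_nimg, resolution_log2=8, num_gpus=4,
                          minibatch_dict={4: 128, 8: 64, 16: 32, 32: 16, 64: 8, 128: 8, 256: 8})
    print(f"kimg={s.kimg:7.1f}  lod={s.lod:4.2f}  res={s.resolution:4d}  "
          f"alpha={s.alpha:4.2f}  minibatch={s.minibatch}  G_lrate={s.G_lrate}")

With the defaults, kimg 0-600 trains at 4x4, kimg 600-1200 fades in 8x8 (alpha ramping toward 1), and the resolution doubles every lod_training_kimg + lod_transition_kimg = 1200 kimg thereafter.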
Example #2
#desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq');             train.mirror_augment = True
#desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full');    train.mirror_augment = False
#desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');     train.mirror_augment = False
#desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');        train.mirror_augment = False

# Number of GPUs.
#desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_size = 4
#desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_size = 32
desc += '-4gpu'
submit_config.num_gpus = 4
sched.minibatch_size = 16
#desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_size = 32

# Default options.
train.total_kimg = 25000
sched.G_lrate = 0.003
sched.D_lrate = sched.G_lrate

# Logging and snapshot frequency:
sched.tick_kimg = 10
image_snapshot_ticks = 1
network_snapshot_ticks = 10

# Debug settings:
# sched.tick_kimg = 0.001
# image_snapshot_ticks = 1
# network_snapshot_ticks = 1

# WGAN-GP loss for CelebA-HQ.
# desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
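
The commented WGAN-GP option above swaps in the loss functions by name and caps every per-resolution learning rate at 0.002. A small sketch of that capping step, using a hypothetical G_lrate_dict purely for illustration:

from dnnlib import EasyDict

sched = EasyDict()
sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}  # hypothetical per-resolution rates

# Cap each entry at 0.002 and mirror the schedule for D, as in the commented option above.
sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}
sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
print(sched.G_lrate_dict)  # -> {128: 0.0015, 256: 0.002, 512: 0.002, 1024: 0.002}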