Beispiel #1
0
def training_schedule(cur_nimg,
                      resolution_log2,   # log2 of the FINAL (target) resolution, not the current one
                      num_gpus,
                      lod_initial_resolution=4,
                      lod_training_kimg=600,
                      lod_transition_kimg=600,
                      minibatch_base=4,
                      minibatch_dict=None,
                      max_minibatch_per_gpu=None,
                      G_lrate_base=0.001,
                      G_lrate_dict=None,
                      D_lrate_base=0.001,
                      D_lrate_dict=None,
                      lrate_rampup_kimg=0,
                      tick_kimg_base=60,  # per the original author: reduced vs. the official implementation because their GPU is slow
                      tick_kimg_dict=None,
                      ):
    """Compute the progressive-growing training parameters for the current point in training.

    Training is divided into repeating phases of ``lod_training_kimg`` kimg
    (stabilize) followed by ``lod_transition_kimg`` kimg (fade in the next
    resolution). Each completed phase lowers the level of detail (``lod``) by
    one, i.e. doubles the working resolution, starting from
    ``lod_initial_resolution`` and stopping at ``2 ** resolution_log2``
    (where ``lod`` is clamped to 0).

    Args:
        cur_nimg: Number of images processed so far (``kimg = cur_nimg / 1000``).
        resolution_log2: log2 of the final/target resolution. The current
            resolution is derived from it and returned as ``resolution``.
        num_gpus: Number of GPUs; the minibatch is rounded down to a multiple
            of it (a minibatch smaller than ``num_gpus`` therefore becomes 0).
        lod_initial_resolution: Resolution at which training starts.
        lod_training_kimg: kimg spent stabilizing at each resolution.
        lod_transition_kimg: kimg spent fading in each new resolution.
        minibatch_base: Minibatch used for resolutions not in ``minibatch_dict``.
        minibatch_dict: Optional ``{resolution: minibatch}`` overrides.
        max_minibatch_per_gpu: Optional ``{resolution: cap}``; the total
            minibatch is capped at ``cap * num_gpus``.
        G_lrate_base: Generator learning rate for resolutions not in ``G_lrate_dict``.
        G_lrate_dict: Optional ``{resolution: lr}`` overrides for the generator.
        D_lrate_base: Discriminator learning rate for resolutions not in ``D_lrate_dict``.
        D_lrate_dict: Optional ``{resolution: lr}`` overrides for the discriminator.
        lrate_rampup_kimg: If > 0, both learning rates are scaled by
            ``min(kimg / lrate_rampup_kimg, 1)``.
        tick_kimg_base: Tick length (kimg) for resolutions not in ``tick_kimg_dict``.
        tick_kimg_dict: Optional ``{resolution: tick_kimg}``. When omitted
            (``None``) it defaults to ``{4: 60, 8: 40, 16..1024: 20}``. Pass
            ``{}`` to always use ``tick_kimg_base``.

    Returns:
        EasyDict with fields ``kimg``, ``lod``, ``resolution``,
        ``resolution_log2``, ``alpha`` (fade-in weight in (0, 1], equal to
        ``1 - frac(lod)``; 1.0 while stabilizing), ``minibatch``, ``G_lrate``,
        ``D_lrate`` and ``tick_kimg``.
    """
    # Resolve None sentinels per call so no dict is shared between calls.
    if minibatch_dict is None:
        minibatch_dict = {}
    if max_minibatch_per_gpu is None:
        max_minibatch_per_gpu = {}
    if G_lrate_dict is None:
        G_lrate_dict = {}
    if D_lrate_dict is None:
        D_lrate_dict = {}
    if tick_kimg_dict is None:
        tick_kimg_dict = {4: 60, 8: 40, 16: 20, 32: 20, 64: 20,
                          128: 20, 256: 20, 512: 20, 1024: 20}

    # dnnlib comes with EasyDict
    s = EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Which (stabilize + transition) phase we are in, and how far into it.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution. lod counts halvings below the final
    # resolution; its fractional part is the progress of the current fade-in.
    s.lod = resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        # Only the part of the phase past the stabilization window contributes.
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    # During a transition, the network already runs at the next (larger) resolution.
    s.resolution = 2 ** (resolution_log2 - int(np.floor(s.lod)))
    s.resolution_log2 = int(np.log2(s.resolution))
    # Equivalent to 1 - (lod - floor(lod)).
    s.alpha = 1 - (s.lod - (resolution_log2 - s.resolution_log2))
    assert 0 <= s.alpha <= 1.0

    # Minibatch size: rounded down to a multiple of num_gpus, then optionally capped.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus)

    # Learning rate, with optional linear warm-up.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s