Example #1
def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0):     # Weight of the conditioning terms.

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/scores/fake', fake_scores_out)
    loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out)

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
        mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/scores/mixed', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    return loss
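Written out per sample, the loss assembled above is the hinge discriminator objective with a WGAN-style one-centered gradient penalty (writing t for wgan_target and \lambda for wgan_lambda):

L_D = \mathbb{E}_z[\max(0,\,1 + D(G(z)))] + \mathbb{E}_x[\max(0,\,1 - D(x))] + \frac{\lambda}{t^2}\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - t)^2\big], \qquad \hat{x} = \alpha x + (1-\alpha)\,G(z),\ \alpha \sim U[0,1]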
Example #2
def D_hrvgan(
    G,
    D,
    R,
    opt,
    training_set,
    minibatch_size,
    reals,
    labels,
    wgan_lambda=10.0,  # Weight for the gradient penalty term.
    wgan_epsilon=0.001,  # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target=1.0,  # Target value for gradient magnitudes.
    cond_weight=1.0):  # Weight of the conditioning terms.
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out, real_labels_out = fp32(
        D.get_output_for(reals, is_training=True))
    fake_scores_out, fake_labels_out = fp32(
        D.get_output_for(fake_images_out, is_training=True))
    real_proj_out, _ = R.get_output_for(real_scores_out)
    real_proj_out = tf.reduce_mean(real_proj_out, [1])
    fake_proj_out, _ = R.get_output_for(fake_scores_out)
    fake_proj_out = tf.reduce_mean(fake_proj_out, [1])
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = real_proj_out - fake_proj_out
    # NOTE: gradient_penalty and features_penalty are computed earlier in the
    # original source; that part of the snippet is not included here.
    loss += (5 * gradient_penalty + features_penalty) * (wgan_lambda /
                                                         (wgan_target**2))
    return loss
Example #3
def D_wgangp(G, D, opt, training_set, minibatch_size, reals, labels,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0):     # Target value for gradient magnitudes.

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
        mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    return loss
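In closed form this is the WGAN-GP critic loss of Gulrajani et al. plus the small drift term used by progressive GAN to keep the critic scores from drifting, with t = wgan_target, \lambda = wgan_lambda and \epsilon_{drift} = wgan_epsilon:

L_D = \mathbb{E}_z[D(G(z))] - \mathbb{E}_x[D(x)] + \frac{\lambda}{t^2}\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - t)^2\big] + \epsilon_{drift}\,\mathbb{E}_x[D(x)^2]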
Example #4
def D_logistic_r(G, D, opt, training_set, minibatch_size, reals, labels, gamma_1=10.0, gamma_2=0.0):
    _ = opt, training_set
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out = D.get_output_for(reals, labels, is_training=True)
    fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
    real_scores_out = tfutil.autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/scores/fake', fake_scores_out)
    loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
    loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type

    if gamma_1 != 0.0:
        with tf.name_scope('R1_GradientPenalty'):
            real_grads = tf.gradients(tf.reduce_sum(real_scores_out), [reals])[0]
            gradient_penalty_1 = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
            gradient_penalty_1 = tfutil.autosummary('Loss/gradient_penalty_r1', gradient_penalty_1)
            reg_1 = gradient_penalty_1 * (gamma_1 * 0.5)
        loss += reg_1

    if gamma_2 != 0.0:
        with tf.name_scope('R2_GradientPenalty'):
            fake_grads = tf.gradients(tf.reduce_sum(fake_scores_out), [fake_images_out])[0]
            gradient_penalty_2 = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3])
            gradient_penalty_2 = tfutil.autosummary('Loss/gradient_penalty_r2', gradient_penalty_2)
            reg_2 = gradient_penalty_2 * (gamma_2 * 0.5)
        loss += reg_2

    return loss
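The two penalties are the R1/R2 regularizers of Mescheder et al., R_k = (\gamma_k/2)\,\mathbb{E}[\lVert\nabla D\rVert_2^2], taken at real samples for R1 and at generated samples for R2. This discriminator is normally paired with the non-saturating generator loss; a minimal sketch under the same conventions (the name G_logistic_ns is illustrative, not part of the original listing):

def G_logistic_ns(G, D, opt, training_set, minibatch_size):
    _ = opt
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
    loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
    return loss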
Example #5
def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels, cond_weight=1.0): # pylint: disable=unused-argument
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/scores/fake', fake_scores_out)
    loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out)

    return loss
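The generator counterpart of the hinge (and Wasserstein) discriminators above simply maximizes the score on fakes, i.e. minimizes -\mathbb{E}_z[D(G(z))]; a minimal sketch assuming the same fp32 helper is in scope (the name G_hinge is illustrative):

def G_hinge(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))
    loss = -fake_scores_out
    return loss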
Example #6
def D_gen_wgangp(
    E_zg,
    E_zl,
    G,
    D_gen,
    D_gen_opt,
    minibatch_size,
    reals_fade,
    wgan_lambda=10.0,  # Weight for the gradient penalty term.
    wgan_epsilon=0.001,  # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target=1.0):  # Target value for gradient magnitudes.

    # random generated realism
    zg_latents = tf.random_normal([minibatch_size] + E_zg.output_shapes[0][1:])
    zl_latents = tf.random_normal([minibatch_size] + E_zl.output_shapes[0][1:])
    fake_images_out = G.get_output_for(
        tf.tile(zg_latents, [1, 1] + E_zl.output_shapes[0][2:]), zl_latents)
    fake_scores_out = fp32(D_gen.get_output_for(fake_images_out))
    real_scores_out = fp32(D_gen.get_output_for(reals_fade))
    gen_D_loss = tf.reduce_mean(fake_scores_out - real_scores_out,
                                axis=[1, 2, 3])
    gen_D_loss = tfutil.autosummary('Loss/gen_D_loss', gen_D_loss)
    loss = tf.identity(gen_D_loss)

    # gradient penalty
    with tf.name_scope('gen_GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1],
                                           0.0,
                                           1.0,
                                           dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(
            tf.cast(reals_fade, fake_images_out.dtype), fake_images_out,
            mixing_factors)
        mixed_scores_out = fp32(D_gen.get_output_for(mixed_images_out))
        mixed_loss = D_gen_opt.apply_loss_scaling(
            tf.reduce_sum(mixed_scores_out))
        mixed_grads = D_gen_opt.undo_loss_scaling(
            fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(
            tf.reduce_sum(tf.square(mixed_grads), axis=[1, 2, 3]))
        gen_gradient_penalty = tf.square(mixed_norms - wgan_target)
        gen_gradient_penalty *= (wgan_lambda / (wgan_target**2))
        gen_gradient_penalty = tfutil.autosummary('Loss/gen_gradient_penalty',
                                                  gen_gradient_penalty)
    loss += gen_gradient_penalty

    # calibration penalty
    with tf.name_scope('gen_EpsilonPenalty'):
        gen_epsilon_penalty = tf.reduce_mean(tf.square(real_scores_out),
                                             axis=[1, 2, 3]) * wgan_epsilon
        gen_epsilon_penalty = tfutil.autosummary('Loss/gen_epsilon_penalty',
                                                 gen_epsilon_penalty)
    loss += gen_epsilon_penalty

    return loss
Example #7
def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):
    _ = opt, training_set
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out = D.get_output_for(reals, labels, is_training=True)
    fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
    real_scores_out = tfutil.autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/scores/fake', fake_scores_out)
    loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
    loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type

    return loss
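The two comments follow from the softplus identities: with \sigma the logistic sigmoid, \mathrm{softplus}(x) = \log(1 + e^x) and 1 - \sigma(x) = \sigma(-x), hence -\log(1 - \sigma(x)) = \mathrm{softplus}(x) and -\log\sigma(x) = \mathrm{softplus}(-x). The sum is therefore the standard logistic discriminator loss -\log(1 - \sigma(D(G(z)))) - \log\sigma(D(x)).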
Example #8
def D_wgangp_acgan(G, D, opt, training_set, minibatch_size, reals, labels, embeddings,
    use_embedding   = True,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0):     # Weight of the conditioning terms.

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, embeddings, is_training=True)
    if use_embedding:
        real_scores_out, real_labels_out, real_embeddings_out = fp32(D.get_output_for(reals, labels, embeddings, is_training=True))
        fake_scores_out, fake_labels_out, fake_embeddings_out = fp32(D.get_output_for(fake_images_out, labels, embeddings, is_training=True))
    else:
        real_scores_out, real_labels_out = fp32(D.get_output_for(reals, labels, embeddings, is_training=True))
        fake_scores_out, fake_labels_out = fp32(D.get_output_for(fake_images_out, labels, embeddings, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
        if use_embedding:
            mixed_scores_out, mixed_labels_out, mixed_embeddings_out = fp32(D.get_output_for(mixed_images_out, labels, embeddings, is_training=True))
        else:
            mixed_scores_out, mixed_labels_out = fp32(D.get_output_for(mixed_images_out, labels, embeddings, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    if D.output_shapes[1][1] > 0:
        if use_embedding:
            with tf.name_scope('LabelPenalty'):
                #label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=real_labels_out)
                #label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
                label_penalty_reals = tf.losses.mean_squared_error(labels, real_labels_out)
                label_penalty_fakes = tf.losses.mean_squared_error(labels, fake_labels_out)
                label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
                label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
            loss += (label_penalty_reals + label_penalty_fakes) * cond_weight
            with tf.name_scope('EmbeddingPenalty'):
                embedding_penalty_reals = tf.losses.mean_squared_error(embeddings, real_embeddings_out)
                embedding_penalty_fakes = tf.losses.mean_squared_error(embeddings, fake_embeddings_out)
                embedding_penalty_reals = tfutil.autosummary('Loss/embedding_penalty_reals', embedding_penalty_reals)
                embedding_penalty_fakes = tfutil.autosummary('Loss/embedding_penalty_fakes', embedding_penalty_fakes)
            loss += (embedding_penalty_reals + embedding_penalty_fakes) * cond_weight

    return loss
Example #9
def D_lsgan(G, D, opt, training_set, minibatch_size, reals, labels):
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out, real_labels_out = fp32(
        D.get_output_for(reals, is_training=True))
    fake_scores_out, fake_labels_out = fp32(
        D.get_output_for(fake_images_out, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)

    loss = tf.reduce_sum(
        tf.square(real_scores_out - 1) + tf.square(fake_scores_out)) / 2
    return loss
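The matching least-squares generator loss (Mao et al.) pulls fake scores toward the real target 1; a minimal sketch under the same conventions (the name G_lsgan is illustrative, not part of the original listing):

def G_lsgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    fake_scores_out, _ = fp32(D.get_output_for(fake_images_out, is_training=True))
    loss = tf.reduce_sum(tf.square(fake_scores_out - 1)) / 2
    return loss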
Example #10
def D_rec_wgangp(
    EG,
    D_rec,
    D_rec_opt,
    minibatch_size,
    reals_orig,
    wgan_lambda=10.0,  # Weight for the gradient penalty term.
    wgan_epsilon=0.001,  # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target=1.0):  # Target value for gradient magnitudes.

    # reconstructed realism
    recs_out, fingerprints_out, logits_out = EG.get_output_for(reals_orig)
    rec_scores_out = fp32(D_rec.get_output_for(recs_out))
    real_scores_out = fp32(D_rec.get_output_for(reals_orig))
    rec_D_loss = tf.reduce_mean(rec_scores_out - real_scores_out,
                                axis=[1, 2, 3])
    rec_D_loss = tfutil.autosummary('Loss/rec_D_loss', rec_D_loss)
    loss = tf.identity(rec_D_loss)

    # gradient penalty
    with tf.name_scope('rec_GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1],
                                           0.0,
                                           1.0,
                                           dtype=recs_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals_orig, recs_out.dtype),
                                       recs_out, mixing_factors)
        mixed_scores_out = fp32(D_rec.get_output_for(mixed_images_out))
        mixed_loss = D_rec_opt.apply_loss_scaling(
            tf.reduce_sum(mixed_scores_out))
        mixed_grads = D_rec_opt.undo_loss_scaling(
            fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(
            tf.reduce_sum(tf.square(mixed_grads), axis=[1, 2, 3]))
        rec_gradient_penalty = tf.square(mixed_norms - wgan_target)
        rec_gradient_penalty *= (wgan_lambda / (wgan_target**2))
        rec_gradient_penalty = tfutil.autosummary('Loss/rec_gradient_penalty',
                                                  rec_gradient_penalty)
    loss += rec_gradient_penalty

    # calibration penalty
    with tf.name_scope('rec_EpsilonPenalty'):
        rec_epsilon_penalty = tf.reduce_mean(tf.square(real_scores_out),
                                             axis=[1, 2, 3]) * wgan_epsilon
        rec_epsilon_penalty = tfutil.autosummary('Loss/rec_epsilon_penalty',
                                                 rec_epsilon_penalty)
    loss += rec_epsilon_penalty

    return loss
Example #11
 def addSinuosityPenalty(index):
     SinuosityPenalty = tf.nn.l2_loss(labels[:, index] -
                                      fake_labels_out[:, index])
     SinuosityPenalty = tfutil.autosummary('Loss_G/SinuosityPenalty',
                                           SinuosityPenalty)
     SinuosityPenalty = SinuosityPenalty * Sinuosity_weight
     return loss + SinuosityPenalty
Example #12
 def addWidthPenalty(index):
     WidthPenalty = tf.nn.l2_loss(labels[:, index] -
                                  fake_labels_out[:, index])
     WidthPenalty = tfutil.autosummary('Loss_G/WidthPenalty',
                                       WidthPenalty)
     WidthPenalty = WidthPenalty * Width_weight
     return loss + WidthPenalty
Example #13
def G_wgan_acgan(G, D, opt, training_set, minibatch_size, faceB, faceA,
    cond_weight = 1.0): # Weight of the conditioning term.

    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(faceA, labels, is_training=True)
    tmp_fake_images_out = tf.concat([fake_images_out, faceA], axis=1)
    reals = tf.concat([faceB, faceA], axis=1)
    fake_scores_out, fake_labels_out = fp32(D.get_output_for(tmp_fake_images_out, is_training=True))

    L1_Reals = misc.adjust_dynamic_tfrange(reals, [-1, 1], [0, 255])
    L1_Fake = misc.adjust_dynamic_tfrange(tmp_fake_images_out, [-1, 1], [0, 255])
    L1_Loss = tf.reduce_mean(tf.abs(L1_Reals - L1_Fake))
    L1_Loss = tfutil.autosummary('Loss/L1_loss', L1_Loss)

    loss = -fake_scores_out + L1_Loss

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
        loss += -label_penalty_fakes * cond_weight

    return loss, tmp_fake_images_out
Example #14
 def addMudPropPenalty(index):
     MudPropPenalty = tf.nn.l2_loss(labels[:, index] -
                                    fake_labels_out[:, index])
     MudPropPenalty = tfutil.autosummary('Loss_G/MudPropPenalty',
                                         MudPropPenalty)
     MudPropPenalty = MudPropPenalty * MudProp_weight
     return loss + MudPropPenalty
Example #15
def G_wgan_acgan(G, D, opt, training_set, minibatch_size,
    cond_weight = 1.0): # Weight of the conditioning term.

    #int(np.log2(G.output_shapes[-1]))
    total_latent_size = G.input_shapes[0][1:][0]
    c_size = 1 + 10 + 1
    random_latent_size = total_latent_size - c_size
    c_3_ind = tf.random_normal([minibatch_size], 0, 1, dtype = tf.float32)
#     c_3 = tf.one_hot(c_3_ind, 2)
    c_4_ind = tf.random_uniform([minibatch_size], 0, 10, dtype = tf.int32)
    c_4 = tf.one_hot(c_4_ind, 10)
    c_5_ind = tf.random_normal([minibatch_size], 0, 1, dtype = tf.float32)
#     c_5 = tf.one_hot(c_5_ind, 2)
    
    test = tf.random_uniform([minibatch_size], 0, 1, dtype = tf.float32)
    c_3 = tf.reshape(c_3_ind, [minibatch_size, 1])
#     c_4 = c_4_ind
    c_5 = tf.reshape(c_5_ind, [minibatch_size, 1])
    
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    fake_scores_out, fake_labels_out, qf3, qf4, qf5, lod_in = fp32(D.get_output_for(fake_images_out, is_training=True))
    loss = -fake_scores_out
    
    loss3 = tf.losses.mean_squared_error(c_3, qf3)
    #loss = tf.Print(loss, [loss3], message="loss3")
    loss4 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=c_4, logits=qf4)
    #loss = tf.Print(loss, [tf.reduce_mean(loss4)], message="loss4")
    loss5 = tf.losses.mean_squared_error(c_5, qf5)
    #loss = tf.Print(loss, [loss5], message="loss5")
    #loss = tf.Print(loss, [tf.reduce_mean(loss), loss3, tf.reduce_mean(loss4), loss5, tf.reduce_mean(loss + loss3 +loss4+loss5)], message="loss")
    
    
    #print(loss.shape, loss3.shape, loss4.shape, loss5.shape, type(loss), type(loss3))
    loss = loss + tf.clip_by_value((1 - lod_in), 0.0, 1.0)*(5 * loss3 + 5 * loss5)
    #loss = loss + 2*loss3 + 0.2*loss4 + 2*loss5
    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
        loss += label_penalty_fakes * cond_weight
    loss3 = tfutil.autosummary('Loss/GInfoLoss3', loss3)
    loss4 = tfutil.autosummary('Loss/GInfoLoss4', loss4)
    loss5 = tfutil.autosummary('Loss/GInfoLoss5', loss5)
    loss = tfutil.autosummary('Loss/GFinalLoss', loss)
    return loss, loss3+loss4+loss5
Example #16
def D_wgangp_acgan(G, D, opt, training_set, minibatch_size, reals, labels,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0):     # Weight of the conditioning terms.
    
    total_latent_size = G.input_shapes[0][1:][0]
    c_size = 1 + 10 + 1
    random_latent_size = total_latent_size - c_size
    c_3_ind = tf.random_normal([minibatch_size], 0, 1, dtype = tf.float32)
#     c_3 = tf.one_hot(c_3_ind, 2)
    c_4_ind = tf.random_uniform([minibatch_size], 0, 10, dtype = tf.int32)
    c_4 = tf.one_hot(c_4_ind, 10)
    c_5_ind = tf.random_normal([minibatch_size], 0, 1, dtype = tf.float32)
#     c_5 = tf.one_hot(c_5_ind, 2)
    
    test = tf.random_uniform([minibatch_size], 0, 1, dtype = tf.float32)
    c_3 = tf.reshape(c_3_ind, [minibatch_size, 1])
#     c_4 = c_4_ind
    c_5 = tf.reshape(c_5_ind, [minibatch_size, 1])
    
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    
    #latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out, real_labels_out, qr3, qr4, qr5, lod_in = fp32(D.get_output_for(reals, is_training=True))
    fake_scores_out, fake_labels_out, qf3, qf4, qf5, lod_in = fp32(D.get_output_for(fake_images_out, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
        mixed_scores_out, mixed_labels_out, q3, q4, q5, lod_in = fp32(D.get_output_for(mixed_images_out, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=real_labels_out)
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
            label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
            label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
        loss += (label_penalty_reals + label_penalty_fakes) * cond_weight
    loss = tfutil.autosummary('Loss/DFinalLoss', loss)    
    return loss
Example #17
def G_wgan_acgan(G, D, opt, training_set, minibatch_size):

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, is_training=True) 
    fake_scores_out, fake_labels_out = fp32(D.get_output_for(fake_images_out, is_training=True))
    loss = -fake_scores_out
    loss = tfutil.autosummary('Loss_G/Total_loss', loss)
    return loss
Example #18
def C_classification(C_im, reals_orig, labels):

    with tf.name_scope('ClassificationPenalty'):
        real_labels_out = fp32(C_im.get_output_for(reals_orig))
        real_class_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels, logits=real_labels_out, dim=1)
        real_class_loss = tfutil.autosummary('Loss/real_class_loss',
                                             real_class_loss)
    loss = tf.identity(real_class_loss)

    return loss
 def addWidthPenalty(index):
     WidthPenalty = tf.nn.l2_loss(
         labels_in[:, index] - fake_labels_out[:, index]
     )  # [:,0] is the inter-channel mud facies ratio
     if lossnorm:
         WidthPenalty = (WidthPenalty - 0.600282781464712
                         ) / 0.270670509379704  # To normalize this loss
     WidthPenalty = tfutil.autosummary('Loss_G/WidthPenalty',
                                       WidthPenalty)
     WidthPenalty = WidthPenalty * Width_weight
     return loss + WidthPenalty
 def addSinuosityPenalty(index):
     SinuosityPenalty = tf.nn.l2_loss(
         labels_in[:, index] - fake_labels_out[:, index]
     )  # [:,0] is the inter-channel mud facies ratio
     if lossnorm:
         SinuosityPenalty = (
             SinuosityPenalty - 0.451279248935835
         ) / 0.145642580091667  # To normalize this loss
     SinuosityPenalty = tfutil.autosummary('Loss_G/SinuosityPenalty',
                                           SinuosityPenalty)
     SinuosityPenalty = SinuosityPenalty * Sinuosity_weight
     return loss + SinuosityPenalty
 def addMudPropPenalty(index):
     MudPropPenalty = tf.nn.l2_loss(
         labels_in[:, index] - fake_labels_out[:, index]
     )  # [:,0] is the inter-channel mud facies ratio
     if lossnorm:
         MudPropPenalty = (
             MudPropPenalty - 0.36079434843794
         ) / 0.11613414177144  # To normalize this loss
     MudPropPenalty = tfutil.autosummary('Loss_G/MudPropPenalty',
                                         MudPropPenalty)
     MudPropPenalty = MudPropPenalty * MudProp_weight
     return loss + MudPropPenalty
 def addwellfaciespenalty(well_facies, fake_images_out, loss,
                          Wellfaciesloss_weight):
     with tf.name_scope('WellfaciesPenalty'):
         WellfaciesPenalty = Wellpoints_L2loss(
             well_facies, fake_images_out
         )  # as levee is 0.5, in well_facies data, levee and channels' codes are all 1
         if lossnorm:
             WellfaciesPenalty = (WellfaciesPenalty -
                                  0.00887323171768953) / 0.00517647244943928
         WellfaciesPenalty = tfutil.autosummary('Loss_G/WellfaciesPenalty',
                                                WellfaciesPenalty)
         loss += WellfaciesPenalty * Wellfaciesloss_weight
     return loss
def D_wgangp_acgan(G, D, opt, training_set, minibatch_size, reals, labels,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0,      # Weight of the conditioning terms.
    shared=False):

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    if shared:
        real_scores_out, real_labels_out = fp32(D.get_output_for(*reals, is_training=True))
        fake_scores_out, fake_labels_out = fp32(D.get_output_for(*fake_images_out, is_training=True))
        out_dtype = fake_images_out[0].dtype
    else:
        real_scores_out, real_labels_out = fp32(D.get_output_for(reals, is_training=True))
        fake_scores_out, fake_labels_out = fp32(D.get_output_for(fake_images_out, is_training=True))
        out_dtype = fake_images_out.dtype
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=out_dtype)
        if shared:
            mixed_images_out = tuple(tfutil.lerp(tf.cast(_reals, out_dtype), _fake_images_out, mixing_factors) for _reals, _fake_images_out in zip(reals, fake_images_out))
            mixed_scores_out, mixed_labels_out = fp32(D.get_output_for(*mixed_images_out, is_training=True))
        else:
            mixed_images_out = tfutil.lerp(tf.cast(reals, out_dtype), fake_images_out, mixing_factors)
            mixed_scores_out, mixed_labels_out = fp32(D.get_output_for(mixed_images_out, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        if shared:
            mixed_grads = opt.undo_loss_scaling(fp32(tf.concat([tf.reshape(x, (minibatch_size, -1)) for x in tf.gradients([mixed_loss], list(mixed_images_out)) if x is not None], axis=1)))
            mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1]))
        else:
            mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
            mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=real_labels_out)
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
            label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
            label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
        loss += (label_penalty_reals + label_penalty_fakes) * cond_weight
    return loss
Example #24
def EG_classification(EG, D_rec, reals_orig, labels, rec_weight, rec_G_weight):

    recs_out, fingerprints_out, logits_out = EG.get_output_for(reals_orig)
    real_class_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=logits_out, dim=1)
    real_class_loss = tfutil.autosummary('Loss/real_class_loss',
                                         real_class_loss)
    loss = tf.identity(real_class_loss)

    if rec_weight > 0.0:
        rec_loss = tf.reduce_mean(tf.abs(recs_out - reals_orig),
                                  axis=[1, 2, 3])
        rec_loss *= rec_weight
        rec_loss = tfutil.autosummary('Loss/rec_loss', rec_loss)
        loss += rec_loss

    if rec_G_weight > 0.0:
        rec_scores_out = fp32(D_rec.get_output_for(recs_out))
        rec_G_loss = tf.reduce_mean(-rec_scores_out, axis=[1, 2, 3])
        rec_G_loss *= rec_G_weight
        rec_G_loss = tfutil.autosummary('Loss/rec_G_loss', rec_G_loss)
        loss += rec_G_loss

    return loss
 def addfaciescodedistributionloss(probs, fakes, weight, batchsize, relzs,
                                   loss):  # used when resolution is 64x64
     with tf.name_scope('ProbimagePenalty'):
         # In the paper, only the probability map for the channel complex is considered. If probability
         # maps for multiple facies are used, channelindicator and ProbPenalty need to be computed per facies.
         channelindicator = 1 / (
             1 + tf.math.exp(-16 * (fakes + 0.5))
         )  # an adjusted sigmoid used as a continuous (differentiable) indicator
         probs_fake = tf.reduce_mean(
             tf.reshape(channelindicator,
                        ([batchsize, relzs] + G.input_shapes[2][1:])), 1)
         ProbPenalty = tf.nn.l2_loss(probs - probs_fake)  # L2 loss
         if lossnorm:
             ProbPenalty = ((ProbPenalty * tf.cast(relzs, tf.float32)) -
                            19134) / 5402  # normalize
         ProbPenalty = tfutil.autosummary('Loss_G/ProbPenalty', ProbPenalty)
     loss += ProbPenalty * weight
     return loss
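The adjusted sigmoid used above is I(x) = 1 / (1 + e^{-16(x + 0.5)}), a soft step centered at x = -0.5 whose sharpness is set by the factor 16: a mud code of -1 gives I(-1) = \sigma(-8) \approx 0.0003, while a channel code of 0 or more gives I(0) = \sigma(8) \approx 0.9997, so it acts as a differentiable stand-in for a hard channel indicator when comparing against the probability maps.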
Example #26
def D_wgangp_acgan(G, D, opt, training_set, minibatch_size, faceB, labels, faceA, generate_y, varforgauss,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0):     # Weight of the conditioning terms.

    tmpreals = tf.concat([faceB, faceA], axis=1)
    real_scores_out, real_labels_out = fp32(D.get_output_for(tmpreals, is_training=True))
    fake_scores_out, fake_labels_out = fp32(D.get_output_for(generate_y, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=generate_y.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(tmpreals, generate_y.dtype), generate_y, mixing_factors)
        mixed_scores_out, mixed_labels_out = fp32(D.get_output_for(mixed_images_out, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=real_labels_out)
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
            label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
            label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
        loss += (label_penalty_reals + label_penalty_fakes) * cond_weight

    return loss
Example #27
def G_wgan_acgan(G,
                 D,
                 opt,
                 training_set,
                 minibatch_size,
                 unlabeled_reals,
                 cond_weight=0.0):  # Weight of the conditioning term.
    '''
    Calculates the feature-matching loss for the generator.
    '''
    # get generated samples
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    # get random labels for the generated samples
    rand_gen_labels = training_set.get_random_labels_tf(minibatch_size)
    # use the generator to deconvolve the latents into images
    fake_images_out = G.get_output_for(latents,
                                       rand_gen_labels,
                                       is_training=True)

    # use the discriminator to get the features from the last convolution layer as well as the logits
    fake_logits_out, _, fake_features_out = fp32(
        D.get_output_for(fake_images_out, is_training=False))
    # Pass the unlabeled real data to the discriminator and grab the real features out from the last convolutional layer
    _, _, real_features_out = fp32(
        D.get_output_for(unlabeled_reals, is_training=False))

    # calculate feature-matching loss
    # mean squared error of fake and real features
    feat_diff = tf.math.reduce_mean(
        fake_features_out, axis=0) - tf.math.reduce_mean(real_features_out,
                                                         axis=0)
    loss = tf.math.reduce_mean(tf.math.square(feat_diff))

    loss = tfutil.autosummary('Loss/G_feat_match_loss', loss)

    # if D.output_shapes[1][1] > 0:
    #     with tf.name_scope('LabelPenalty'):
    #         # pass fake logits and labels to a softmax layer
    #         label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=rand_gen_labels, logits=fake_logits_out)
    #     loss += label_penalty_fakes * cond_weight
    # loss = tfutil.autosummary('Loss/G_feat_match_loss_post_LabelPenalty', loss)
    return loss
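The quantity computed here is the feature-matching objective of Salimans et al., \lVert\,\mathbb{E}_x[f(x)] - \mathbb{E}_z[f(G(z))]\,\rVert_2^2 (up to the 1/d normalization introduced by the final reduce_mean), where f(\cdot) denotes the last-convolutional-layer features returned by D and the expectations are minibatch averages.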
Example #28
def D_wgangp_acgan(G, D, opt, training_set, minibatch_size, reals, labels,
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0,      # Target value for gradient magnitudes.
    cond_weight     = 1.0):     # Weight of the conditioning terms.

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out, real_labels_out = fp32(D.get_output_for(reals, is_training=True))
    fake_scores_out, fake_labels_out = fp32(D.get_output_for(fake_images_out, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
        mixed_scores_out, mixed_labels_out = fp32(D.get_output_for(mixed_images_out, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=real_labels_out)
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=fake_labels_out)
            label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
            label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
        loss += (label_penalty_reals + label_penalty_fakes) * cond_weight
    return loss
Example #29
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs, E = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)

            E = tfutil.Network('E',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.E)

            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()
    E.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    E_opt = tfutil.Optimizer(name='TrainE',
                             learning_rate=lrate_in,
                             **config.E_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
                tf.assign(E_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            with tf.name_scope('E_loss'), tf.control_dependencies(
                    lod_assign_ops):
                E_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=E_opt,
                    training_set=training_set,
                    reals=reals_gpu,
                    minibatch_size=minibatch_split,
                    **config.E_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            E_opt.register_gradients(tf.reduce_mean(E_loss), E_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    E_train_op = E_opt.apply_updates()

    #sys.exit(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        E.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                E_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op, E_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

                misc.save_all_res(training_set.shape[1],
                                  Gs,
                                  result_subdir,
                                  50,
                                  minibatch_size=sched.minibatch //
                                  config.num_gpus)

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs, E),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs, E),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
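The loss functions in the earlier examples are selected by name in this training loop through tfutil.call_func_by_name(..., **config.G_loss) and **config.D_loss. A minimal sketch of the corresponding config entries, in the EasyDict style of the progressive-GAN codebase (values illustrative, not from the original listing):

# config.py sketch: 'func' picks the loss function by dotted name;
# the remaining entries are forwarded to it as keyword arguments.
G_loss = EasyDict(func='loss.G_wgan_acgan', cond_weight=1.0)
D_loss = EasyDict(func='loss.D_wgangp_acgan', wgan_lambda=10.0,
                  wgan_epsilon=0.001, wgan_target=1.0)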
Example #30
def D_wgangp_acgan(
    G,
    D,
    opt,
    training_set,
    minibatch_size,
    reals,
    labels,
    unlabeled_reals,
    wgan_lambda=0.0,  # Weight for the gradient penalty term.
    wgan_epsilon=0.0,  # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target=0.1,  # Target value for gradient magnitudes.
    cond_weight=0.0):  # Weight of the conditioning terms.

    # Generate latents and pass them through the generator to deconvolve into fake images
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)

    # REALS
    output_before_softmax_lab, real_flogit_out, _ = fp32(
        D.get_output_for(reals, is_training=True))
    # UNLABELED REALS
    output_before_softmax_unl, _, _ = fp32(
        D.get_output_for(unlabeled_reals, is_training=True))
    # GENERATED
    output_before_softmax_fake, fake_flogit_out, _ = fp32(
        D.get_output_for(fake_images_out, is_training=True))

    # Direct port of the labeled loss from Tim Salimans et al. https://arxiv.org/pdf/1606.03498.pdf
    # no support for tensor indexing, so this doesn't work
    #simple_labels = tf.argmax(labels, axis=1)
    #z_exp_lab = tf.math.reduce_mean(tf.math.reduce_logsumexp(output_before_softmax_lab, axis=1))
    #l_lab = output_before_softmax_lab[tf.range(minibatch_size), simple_labels]
    #loss_lab = -tf.math.reduce_mean(l_lab) + tf.math.reduce_mean(z_exp_lab)

    train_err = tf.math.reduce_mean(
        tf.cast(
            tf.math.not_equal(
                tf.math.argmax(output_before_softmax_lab, axis=1),
                tf.math.argmax(labels, axis=1)), tf.float32))
    train_err = tfutil.autosummary('Loss/D_train_err', train_err)

    # labeled sample loss is equivalent to cross entropy w/ softmax (I think?)
    loss_lab = tf.math.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels, logits=output_before_softmax_lab))

    # Another implementation of Salimans' code ported to TF (NOT WORKING: the tf.gather usage is wrong)
    #l_lab = tf.gather(output_before_softmax_lab, tf.range(minibatch_size),labels)
    #loss_lab = -tf.math.reduce_sum(l_lab) + tf.math.reduce_sum(tf.math.reduce_sum(tf.math.reduce_logsumexp(output_before_softmax_lab)))

    # Direct port of unlabeled loss and fake loss. from Tim Salimans et al. https://arxiv.org/pdf/1606.03498.pdf
    # Code reference https://github.com/openai/improved-gan/blob/master/mnist_svhn_cifar10/train_cifar_feature_matching.py#L87
    #z_exp_unl = tf.math.reduce_mean(tf.math.reduce_logsumexp(output_before_softmax_unl, axis=1))
    loss_unl = -0.5*tf.math.reduce_mean(tf.math.reduce_logsumexp(output_before_softmax_unl, axis=1)) + \
               0.5*tf.math.reduce_mean(tf.math.softplus(tf.math.reduce_logsumexp(output_before_softmax_unl, axis=1)))
    loss_fake = 0.5 * tf.math.reduce_mean(
        tf.math.softplus(
            tf.math.reduce_logsumexp(output_before_softmax_fake, axis=1)))
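    # Why these expressions: with K class logits the implicit binary
    # discriminator is D(x) = Z(x) / (Z(x) + 1), Z(x) = sum_k exp(l_k(x)),
    # i.e. D(x) = sigmoid(logsumexp(l(x))). Therefore
    #   -log D(x)      = -logsumexp(l(x)) + softplus(logsumexp(l(x)))
    #   -log(1 - D(x)) =  softplus(logsumexp(l(x)))
    # so loss_unl = 0.5 * E[-log D(x_unl)] and
    #    loss_fake = 0.5 * E[-log(1 - D(x_fake))], as in Salimans et al.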

    # Using autosummary for TensorBoard
    loss_lab = tfutil.autosummary('Loss/D_loss_lab', loss_lab)
    loss_unl = tfutil.autosummary('Loss/D_loss_unl', loss_unl)
    loss_fake = tfutil.autosummary('Loss/D_loss_fake', loss_fake)

    # Combine the losses; the (train_err * 0) term keeps the train_err
    # autosummary in the graph without affecting the objective.
    loss = loss_lab + loss_unl + loss_fake + (train_err * 0)

    loss = tfutil.autosummary('Loss/D_combined_loss', loss)

    # with tf.name_scope('GradientPenalty'):
    #     mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
    #     mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
    #     mixed_scores_out, mixed_labels_out, _ = fp32(D.get_output_for(mixed_images_out, is_training=True))
    #     mixed_scores_out = tfutil.autosummary('Loss/mixed_scores', mixed_scores_out)
    #     mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
    #     mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
    #     mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
    #     mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
    #     gradient_penalty = tf.square(mixed_norms - wgan_target)
    # loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    # with tf.name_scope('EpsilonPenalty'):
    #     epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty', tf.square(real_flogit_out))
    # loss += epsilon_penalty * wgan_epsilon

    # if D.output_shapes[1][1] > 0:
    #     with tf.name_scope('LabelPenalty'):
    #         label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=output_before_softmax_lab)
    #         label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=output_before_softmax_fake)
    #         label_penalty_reals = tfutil.autosummary('Loss/label_penalty_reals', label_penalty_reals)
    #         label_penalty_fakes = tfutil.autosummary('Loss/label_penalty_fakes', label_penalty_fakes)
    #     loss += (label_penalty_reals + label_penalty_fakes) * cond_weight

    # loss = tfutil.autosummary('Loss/D_combined_loss_post_penalties', loss)
    return loss
Example #31
0
def D_wgangp_acgan(
        G,
        D,
        opt,
        training_set,
        minibatch_size,
        reals,
        labels,
        wgan_lambda=10.0,  # Weight for the gradient penalty term.
        wgan_epsilon=0.001,  # Weight for the epsilon term, \epsilon_{drift}.
        wgan_target=1.0,  # Target value for gradient magnitudes.
        cond_weight=1.0,  # Weight of the conditioning terms.
        fingerprint_weight=20.0,  # Weight of the fingerprint terms.
):
    #print("#training_set.shape: ", training_set.shape) #training_set.shape:  [3, 128, 128]
    #print("#reals.shape: ", reals.shape)               #reals.shape:  (?, ?, ?, ?)

    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    real_scores_out, real_labels_out, real_features_out = fp32(
        D.get_output_for(reals, is_training=True))
    fake_scores_out, fake_labels_out, fake_features_out = fp32(
        D.get_output_for(fake_images_out, is_training=True))
    real_scores_out = tfutil.autosummary('Loss/real_scores', real_scores_out)
    fake_scores_out = tfutil.autosummary('Loss/fake_scores', fake_scores_out)
    loss = fake_scores_out - real_scores_out
    #print("#loss.shape: ", loss.shape) #loss.shape:  (?, 1)

    # Add Fingerprints loss
    with tf.name_scope('FingerprintPenalty'):
        #real_features_out = feature_extractor(image_batch=reals, image_shape=training_set.shape, batch_size=minibatch_size_np, extractor_dir=extractor_dir)
        #fake_features_out = feature_extractor(image_batch=fake_images_out, image_shape=training_set.shape, batch_size=minibatch_size_np, extractor_dir=extractor_dir)
        fingerprints_penalty = tf.reduce_mean(tf.abs(real_features_out -
                                                     fake_features_out),
                                              axis=1,
                                              keep_dims=False)
        fingerprints_penalty = tfutil.autosummary('Loss/fingerprints_scores',
                                                  fingerprints_penalty)
        #print("#fingerprints_penalty.shape: ", fingerprints_penalty.shape) #fingerprints_penalty.shape:  (minibatch_size,)
    fingerprints_penalty *= fingerprint_weight
    loss += fingerprints_penalty
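    # Note: loss has shape [N, 1] while fingerprints_penalty has shape [N]
    # (per the print comments above), so the addition broadcasts to [N, N].
    # The training objective is still the intended one because the optimizer
    # applies tf.reduce_mean, and mean_ij(a_i + b_j) = mean(a) + mean(b).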

    with tf.name_scope('GradientPenalty'):
        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1],
                                           0.0,
                                           1.0,
                                           dtype=fake_images_out.dtype)
        mixed_images_out = tfutil.lerp(tf.cast(reals, fake_images_out.dtype),
                                       fake_images_out, mixing_factors)
        mixed_scores_out, mixed_labels_out, mixed_features_out = fp32(
            D.get_output_for(mixed_images_out, is_training=True))
        mixed_scores_out = tfutil.autosummary('Loss/mixed_scores',
                                              mixed_scores_out)
        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))
        mixed_grads = opt.undo_loss_scaling(
            fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))
        mixed_norms = tf.sqrt(
            tf.reduce_sum(tf.square(mixed_grads), axis=[1, 2, 3]))
        mixed_norms = tfutil.autosummary('Loss/mixed_norms', mixed_norms)
        gradient_penalty = tf.square(mixed_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))
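    # With wgan_target == 1.0 this reduces to the standard WGAN-GP penalty
    # lambda * (||grad|| - 1)^2; the 1 / wgan_target**2 factor keeps the
    # effective weight comparable for other target magnitudes.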

    with tf.name_scope('EpsilonPenalty'):
        epsilon_penalty = tfutil.autosummary('Loss/epsilon_penalty',
                                             tf.square(real_scores_out))
    loss += epsilon_penalty * wgan_epsilon
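    # The epsilon term penalizes the squared real scores, keeping D's output
    # from drifting far from zero (the eps_drift term of Karras et al. 2018).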

    if D.output_shapes[1][1] > 0:
        with tf.name_scope('LabelPenalty'):
            label_penalty_reals = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels, logits=real_labels_out)
            label_penalty_fakes = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels, logits=fake_labels_out)
            label_penalty_reals = tfutil.autosummary(
                'Loss/label_penalty_reals', label_penalty_reals)
            label_penalty_fakes = tfutil.autosummary(
                'Loss/label_penalty_fakes', label_penalty_fakes)
        loss += (label_penalty_reals + label_penalty_fakes) * cond_weight
    return loss  #, fingerprints_penalty

def train_progressive_gan(
    G_smoothing             = 0.999,        # Exponential running average of generator weights.
    D_repeats               = 1,            # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,         # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,        # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,        # Enable mirror augment?
    drange_net              = [-1,1],       # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,            # How often to export image snapshots?
    network_snapshot_ticks  = 10,           # How often to export network snapshots?
    save_tf_graph           = False,        # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,        # Include weight histograms in the tfevents file?
    resume_run_id           = None,         # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,         # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,          # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):         # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G)
            D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels   = training_set.get_minibatch_tf()
        reals_split     = tf.split(reals, config.num_gpus)
        labels_split    = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod
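        # Example: lod 3.0 -> 2.9 changes floor(lod) (a new layer starts fading
        # in) and lod 2.1 -> 2.0 changes ceil(lod) (the fade-in completes);
        # both transitions trigger the optimizer reset above.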

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tfutil.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (
                tfutil.autosummary('Progress/tick', cur_tick),
                tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                tfutil.autosummary('Progress/lod', sched.lod),
                tfutil.autosummary('Progress/minibatch', sched.minibatch),
                misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
                tfutil.autosummary('Timing/sec_per_tick', tick_time),
                tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                tfutil.autosummary('Timing/maintenance_sec', maintenance_time)))
            tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
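
# The losses above are selected at runtime via config.G_loss / config.D_loss
# and dispatched with tfutil.call_func_by_name. A minimal sketch of such a
# dispatcher, assuming the config dict supplies a dotted path in its 'func'
# key (an illustration, not the repo's exact implementation):
import importlib

def call_func_by_name(*args, func=None, **kwargs):
    # Split 'module.function' into module path and attribute name, import the
    # module, and call the attribute with the remaining arguments.
    assert func is not None
    module_name, _, func_name = func.rpartition('.')
    return getattr(importlib.import_module(module_name), func_name)(*args, **kwargs)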