예제 #1
0
def train_D(A, B, A2B, B2A):
    """Take one joint optimisation step for the discriminators ``D_A`` and ``D_B``.

    ``D_A`` scores real ``A`` against translated ``B2A``; ``D_B`` scores real
    ``B`` against translated ``A2B``. The objective minimised with the shared
    ``D_optimizer`` is the sum of both adversarial losses plus the two gradient
    penalties scaled by ``args.gradient_penalty_weight``.

    Returns:
        dict with ``'A_d_loss'`` / ``'B_d_loss'`` (each discriminator's
        adversarial loss, penalty excluded) and ``'D_A_gp'`` / ``'D_B_gp'``
        (the unweighted gradient penalties).
    """
    penalty_mode = args.gradient_penalty_mode
    # One training-mode callable per discriminator, reused for the forward
    # passes and for the gradient penalty.
    critic_a = functools.partial(D_A, training=True)
    critic_b = functools.partial(D_B, training=True)

    with tf.GradientTape() as tape:
        # Forward passes, in the same order as the reference implementation.
        real_a_logits = critic_a(A)
        fake_a_logits = critic_a(B2A)
        real_b_logits = critic_b(B)
        fake_b_logits = critic_b(A2B)

        real_a_loss, fake_a_loss = d_loss_fn(real_a_logits, fake_a_logits)
        real_b_loss, fake_b_loss = d_loss_fn(real_b_logits, fake_b_logits)
        gp_a = gan.gradient_penalty(critic_a, A, B2A, mode=penalty_mode)
        gp_b = gan.gradient_penalty(critic_b, B, A2B, mode=penalty_mode)

        adv_a = real_a_loss + fake_a_loss
        adv_b = real_b_loss + fake_b_loss
        total = adv_a + adv_b + (gp_a + gp_b) * args.gradient_penalty_weight

    # Collected after the tape so lazily-built variables already exist.
    variables = D_A.trainable_variables + D_B.trainable_variables
    grads = tape.gradient(total, variables)
    D_optimizer.apply_gradients(zip(grads, variables))

    return {
        'A_d_loss': adv_a,
        'B_d_loss': adv_b,
        'D_A_gp': gp_a,
        'D_B_gp': gp_b
    }
예제 #2
0
def train_D(x_real):
    """Run one optimisation step for the two discriminators ``D1`` and ``D2``.

    The generator ``G`` is fed ``get_Mask(x_real)`` rather than random noise.
    Its two outputs are scored by ``D1`` against ``get_PET(x_real)`` and by
    ``D2`` against ``get_CT(x_real)``. Each discriminator's loss is its
    adversarial loss (``d_loss_fn``) plus a gradient penalty weighted by
    ``args.gradient_penalty_weight``. Both discriminators are updated with the
    shared ``D_optimizer`` in a single ``apply_gradients`` call.

    Args:
        x_real: a real training batch from which ``get_PET`` / ``get_CT`` /
            ``get_Mask`` extract their inputs (presumably separate channels or
            modalities -- TODO confirm against those helpers).

    Returns:
        dict with ``'d1_loss'`` / ``'d2_loss'`` (adversarial loss, penalty
        excluded) and ``'gp1'`` / ``'gp2'`` (unweighted gradient penalties).
    """
    # persistent=True because two separate gradients are taken from this tape.
    with tf.GradientTape(persistent=True) as t:
        # Changed by K.C: the generator input is derived from the real batch
        # (its mask, via get_Mask) instead of the original random noise Z.
        x1_real = get_PET(x_real)
        x2_real = get_CT(x_real)

        x1_fake, x2_fake = G(get_Mask(x_real), training=True)

        x1_real_d_logit = D1(x1_real, training=True)
        x1_fake_d_logit = D1(x1_fake, training=True)
        x2_real_d_logit = D2(x2_real, training=True)
        x2_fake_d_logit = D2(x2_fake, training=True)
        x1_real_d_loss, x1_fake_d_loss = d_loss_fn(x1_real_d_logit, x1_fake_d_logit)
        x2_real_d_loss, x2_fake_d_loss = d_loss_fn(x2_real_d_logit, x2_fake_d_logit)

        gp1 = gan.gradient_penalty(functools.partial(D1, training=True), x1_real, x1_fake, mode=args.gradient_penalty_mode)
        gp2 = gan.gradient_penalty(functools.partial(D2, training=True), x2_real, x2_fake, mode=args.gradient_penalty_mode)

        d1_adv_loss = x1_real_d_loss + x1_fake_d_loss
        d2_adv_loss = x2_real_d_loss + x2_fake_d_loss
        D1_loss = d1_adv_loss + gp1 * args.gradient_penalty_weight
        D2_loss = d2_adv_loss + gp2 * args.gradient_penalty_weight

    D1_grad = t.gradient(D1_loss, D1.trainable_variables)
    D2_grad = t.gradient(D2_loss, D2.trainable_variables)
    # A persistent tape keeps its recorded resources alive until released.
    del t

    # D1 and D2 share D_optimizer. Calling apply_gradients twice advanced the
    # optimizer's step counter (`iterations`) twice per training step, which
    # skews anything keyed on it (e.g. Adam-style bias correction, learning-rate
    # schedules). Apply both gradient lists in one call instead.
    D_optimizer.apply_gradients(
        zip(list(D1_grad) + list(D2_grad),
            D1.trainable_variables + D2.trainable_variables))

    return {'d1_loss': d1_adv_loss, 'gp1': gp1, 'd2_loss': d2_adv_loss, 'gp2': gp2}
예제 #3
0
def train_D(x_real):
    """Take one optimisation step for the discriminator ``D``.

    A fake batch is generated by ``G`` from standard-normal noise of shape
    ``(args.batch_size, 1, 1, args.z_dim)``. ``D`` is then trained on the
    adversarial loss from ``d_loss_fn`` plus a gradient penalty weighted by
    ``args.gradient_penalty_weight``.

    Returns:
        dict with ``'d_loss'`` (adversarial loss only), ``'gp'`` (unweighted
        gradient penalty) and ``'D_loss'`` (the full minimised objective).
    """
    noise_shape = (args.batch_size, 1, 1, args.z_dim)
    critic = functools.partial(D, training=True)

    with tf.GradientTape() as tape:
        x_fake = G(tf.random.normal(shape=noise_shape), training=True)

        real_logits = critic(x_real)
        fake_logits = critic(x_fake)
        real_loss, fake_loss = d_loss_fn(real_logits, fake_logits)
        gp = gan.gradient_penalty(critic, x_real, x_fake,
                                  mode=args.gradient_penalty_mode)

        adversarial = real_loss + fake_loss
        D_loss = adversarial + gp * args.gradient_penalty_weight

    grads = tape.gradient(D_loss, D.trainable_variables)
    D_optimizer.apply_gradients(zip(grads, D.trainable_variables))

    return {'d_loss': adversarial, 'gp': gp, 'D_loss': D_loss}
예제 #4
0
def train_D(A, B, A2B, B2A):
    """Update both CycleGAN-style discriminators with a single optimizer step.

    ``D_A`` discriminates real ``A`` from translated ``B2A``; ``D_B``
    discriminates real ``B`` from translated ``A2B``. The minimised objective
    is both adversarial losses (from ``d_loss_fn``) plus both gradient
    penalties times ``args.gradient_penalty_weight``.

    Returns:
        dict with ``'A_d_loss'`` / ``'B_d_loss'`` (adversarial loss per
        discriminator, penalty excluded) and ``'D_A_gp'`` / ``'D_B_gp'``
        (unweighted gradient penalties).
    """
    weight = args.gradient_penalty_weight
    mode = args.gradient_penalty_mode

    with tf.GradientTape() as tape:
        # Score real and translated images with each domain's discriminator.
        logits_A = D_A(A, training=True)
        logits_B2A = D_A(B2A, training=True)
        logits_B = D_B(B, training=True)
        logits_A2B = D_B(A2B, training=True)

        # Adversarial terms. The exact loss form is whatever d_loss_fn
        # implements (earlier comments called it least-squares GAN -- TODO
        # confirm against d_loss_fn).
        loss_A_real, loss_B2A_fake = d_loss_fn(logits_A, logits_B2A)
        loss_B_real, loss_A2B_fake = d_loss_fn(logits_B, logits_A2B)

        # Gradient penalties: D_A between A and B2A, D_B between B and A2B.
        gp_A = gan.gradient_penalty(functools.partial(D_A, training=True),
                                    A, B2A, mode=mode)
        gp_B = gan.gradient_penalty(functools.partial(D_B, training=True),
                                    B, A2B, mode=mode)

        d_A_total = loss_A_real + loss_B2A_fake
        d_B_total = loss_B_real + loss_A2B_fake
        objective = d_A_total + d_B_total + (gp_A + gp_B) * weight

    d_vars = [*D_A.trainable_variables, *D_B.trainable_variables]
    d_grads = tape.gradient(objective, d_vars)
    D_optimizer.apply_gradients(zip(d_grads, d_vars))

    return {
        'A_d_loss': d_A_total,
        'B_d_loss': d_B_total,
        'D_A_gp': gp_A,
        'D_B_gp': gp_B
    }