def train_van_step(x_real, y_real):
  gen.train()
  dis.train()
  enc.train()

  if n_dim > 0:
    padding = tf.zeros((y_real.shape[0], n_dim))
    y_real_pad = tf.concat((y_real, padding), axis=-1)
  else:
    y_real_pad = y_real

  # Alternate generator step and discriminator/encoder step
  with tf.GradientTape(persistent=False) as tape:
    # Generate
    z_fake = datasets.paired_randn(batch_size, z_dim, masks)
    z_fake = z_fake + y_real_pad
    x_fake = gen(z_fake)

    # Discriminate
    logits_fake = dis(x_fake, y_real)

    gen_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake,
                                                 labels=targets_real))

  gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
  gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

  with tf.GradientTape(persistent=True) as tape:
    # Generate
    z_fake = datasets.paired_randn(batch_size, z_dim, masks)
    z_fake = z_fake + y_real_pad
    x_fake = tf.stop_gradient(gen(z_fake))

    # Discriminate
    x = tf.concat((x_real, x_fake), 0)
    y = tf.concat((y_real, y_real), 0)
    logits = dis(x, y)

    # Encode
    p_z = enc(x_fake)

    dis_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))

    # Encoder ignores nuisance parameters (if they exist)
    enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

  dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
  enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

  dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
  enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

  return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)
def train_gen_step(x1_real, x2_real, y_real):
  gen.train()
  dis.train()
  enc.train()

  # Alternate discriminator/encoder step and generator step
  with tf.GradientTape(persistent=True) as tape:
    # Generate
    z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
    x1_fake = tf.stop_gradient(gen(z1))
    x2_fake = tf.stop_gradient(gen(z2))

    # Discriminate
    x1 = tf.concat((x1_real, x1_fake), 0)
    x2 = tf.concat((x2_real, x2_fake), 0)
    y = tf.concat((y_real, y_fake), 0)
    logits = dis(x1, x2, y)

    # Encode
    p_z = enc(x1_fake)

    dis_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))

    # Encoder ignores nuisance parameters (if they exist)
    enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

  dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
  enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

  dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
  enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

  with tf.GradientTape(persistent=False) as tape:
    # Generate
    z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
    x1_fake = gen(z1)
    x2_fake = gen(z2)

    # Discriminate
    logits_fake = dis(x1_fake, x2_fake, y_fake)

    gen_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake,
                                                 labels=targets_real))

  gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
  gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

  return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)
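# A minimal sketch of how these step functions might be driven from an outer
# training loop. The names `dset` (a tf.data.Dataset yielding batches) and
# `use_paired_data` are hypothetical placeholders, not defined in this file;
# the surrounding code's actual loop may differ. Kept commented out so this
# module stays importable without the closures (gen, dis, enc, optimizers,
# etc.) being bound yet.
#
# for global_step, batch in enumerate(dset):
#   if use_paired_data:
#     # Paired samples (x1, x2) sharing factors indicated by `masks`.
#     x1_real, x2_real, y_real = batch
#     losses = train_gen_step(x1_real, x2_real, y_real)
#   else:
#     # Single samples with observed labels y.
#     x_real, y_real = batch
#     losses = train_van_step(x_real, y_real)
#   if global_step % 100 == 0:
#     print(global_step, {k: float(v) for k, v in losses.items()})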