Example #1
0
    def gen_training_step(real_x, real_x_p):
        """Run one generator update driven by sampled latent noise.

        Two fake batches are produced from separate draws of ``dist_z`` and
        ``dist_y`` and scored against the two real batches with the mixed
        Sinkhorn loss from ``gan_utils``. The generator's trainable variables
        are then updated by ``gen_optimiser`` to minimise that loss.

        Args:
            real_x: a batch of real data.
            real_x_p: a second batch of real data, used for the ``*_p``
                terms of the mixed Sinkhorn loss.

        Returns:
            The mixed Sinkhorn loss computed for this step.
        """
        hidden_z = dist_z.sample([batch_size, time_steps, z_width * z_height])
        hidden_z_p = dist_z.sample(
            [batch_size, time_steps, z_width * z_height])

        hidden_y = dist_y.sample([batch_size, y_dims])
        hidden_y_p = dist_y.sample([batch_size, y_dims])

        with tf.GradientTape() as gen_tape:
            fake_data = generator.call(hidden_z, hidden_y)
            fake_data_p = generator.call(hidden_z_p, hidden_y_p)

            h_fake = discriminator_h.call(fake_data)

            m_real = discriminator_m.call(real_x)
            m_fake = discriminator_m.call(fake_data)

            h_real_p = discriminator_h.call(real_x_p)
            h_fake_p = discriminator_h.call(fake_data_p)

            # Every real-data term must come from this call's arguments
            # (real_x / real_x_p), never from an outer-scope variable.
            m_real_p = discriminator_m.call(real_x_p)

            loss2 = gan_utils.compute_mixed_sinkhorn_loss(
                real_x, fake_data, m_real, m_fake, h_fake, scaling_coef,
                sinkhorn_eps, sinkhorn_l, real_x_p, fake_data_p, m_real_p,
                h_real_p, h_fake_p)

            # The generator minimises the Sinkhorn loss directly.
            gen_loss = loss2
        # update generator parameters
        generator_grads = gen_tape.gradient(gen_loss,
                                            generator.trainable_variables)
        gen_optimiser.apply_gradients(
            zip(generator_grads, generator.trainable_variables))
        return loss2
Example #2
0
    def gen_training_step(real_data, real_data_p):
        """Run one generator update for the sequence-prediction setup.

        Each real batch is split along axis 1 at ``time_steps // 2``. The
        first part is the generator's conditioning input, and the second
        part is the target that the generated output is compared against.
        The comparison uses the mixed Sinkhorn loss from ``gan_utils``.

        Args:
            real_data: a batch of real sequences, sliced as ``[:, t, :]``
                (rank-3; axis 1 is the one split at ``time_steps // 2``).
            real_data_p: a second batch of real sequences, sliced the same
                way and used for the ``*_p`` terms of the loss.

        Returns:
            The mixed Sinkhorn loss computed for this step.
        """
        split = time_steps // 2
        # Targets are the tail of each sequence; inputs are the head.
        target = real_data[:, split:, :]
        target_p = real_data_p[:, split:, :]
        context = real_data[:, :split, :]
        context_p = real_data_p[:, :split, :]

        with tf.GradientTape() as tape:
            generated = generator.call(context)
            generated_p = generator.call(context_p)

            h_generated = discriminator_h.call(generated)
            m_target = discriminator_m.call(target)
            m_generated = discriminator_m.call(generated)
            h_target_p = discriminator_h.call(target_p)
            h_generated_p = discriminator_h.call(generated_p)
            m_target_p = discriminator_m.call(target_p)

            sinkhorn_loss = gan_utils.compute_mixed_sinkhorn_loss(
                target, generated, m_target, m_generated, h_generated,
                scaling_coef, sinkhorn_eps, sinkhorn_l,
                target_p, generated_p, m_target_p, h_target_p, h_generated_p)

        # The generator descends on the Sinkhorn loss directly.
        params = generator.trainable_variables
        grads = tape.gradient(sinkhorn_loss, params)
        gen_optimiser.apply_gradients(zip(grads, params))
        return sinkhorn_loss
Example #3
0
    def disc_training_step(real_data, real_data_p):
        """Run one joint update of discriminators ``h`` and ``m``.

        Two fake batches are generated from separate draws of ``dist_z`` and
        ``dist_y``. The discriminators are trained to maximise the mixed
        Sinkhorn loss, which is why the loss is negated. The martingale
        regulariser ``pm1``, computed on ``m_real``, is added as a penalty.

        Args:
            real_data: a batch of real data.
            real_data_p: a second batch of real data, used for the ``*_p``
                terms of the mixed Sinkhorn loss.
        """
        hidden_z = dist_z.sample([batch_size, time_steps, z_dims_t])
        hidden_z_p = dist_z.sample([batch_size, time_steps, z_dims_t])
        hidden_y = dist_y.sample([batch_size, y_dims])
        hidden_y_p = dist_y.sample([batch_size, y_dims])

        # gradient() is called exactly once below, so a non-persistent tape
        # suffices; its resources are released as soon as it is used.
        with tf.GradientTape() as disc_tape:
            fake_data = generator.call(hidden_z, hidden_y)
            fake_data_p = generator.call(hidden_z_p, hidden_y_p)

            h_fake = discriminator_h.call(fake_data)

            m_real = discriminator_m.call(real_data)
            m_fake = discriminator_m.call(fake_data)

            h_real_p = discriminator_h.call(real_data_p)
            h_fake_p = discriminator_h.call(fake_data_p)

            m_real_p = discriminator_m.call(real_data_p)

            loss1 = gan_utils.compute_mixed_sinkhorn_loss(
                real_data, fake_data, m_real, m_fake, h_fake, scaling_coef,
                sinkhorn_eps, sinkhorn_l, real_data_p, fake_data_p, m_real_p,
                h_real_p, h_fake_p)
            pm1 = gan_utils.scale_invariante_martingale_regularization(
                m_real, reg_penalty, scaling_coef)
            # Discriminators ascend the Sinkhorn loss, penalised by pm1.
            disc_loss = - loss1 + pm1
        # update discriminator parameters
        disc_vars = (list(discriminator_h.trainable_variables)
                     + list(discriminator_m.trainable_variables))
        disc_grads = disc_tape.gradient(disc_loss, disc_vars)
        # One apply_gradients call for both discriminators. The optimizer's
        # step counter then advances once per training step, so bias
        # correction and LR schedules stay in sync for h and m. Newer Keras
        # optimizers (TF >= 2.11) also reject separate calls that update
        # different variable subsets.
        dischm_optimiser.apply_gradients(zip(disc_grads, disc_vars))
Example #4
0
    def disc_training_step(real_data, real_data_p):
        """Run one joint update of discriminators ``h`` and ``m`` (prediction setup).

        Each real batch is split along axis 1 at ``time_steps // 2``. The
        first part is fed to the generator, and the second part is the real
        target. The discriminators are trained to maximise the mixed
        Sinkhorn loss, which is why it is negated. The martingale
        regulariser ``pm1``, computed on ``m_real``, is added as a penalty.

        Args:
            real_data: a batch of real sequences, sliced as ``[:, t, :]``
                (rank-3; axis 1 is the one split at ``time_steps // 2``).
            real_data_p: a second batch of real sequences, sliced the same
                way and used for the ``*_p`` terms of the loss.
        """
        # split real data to training inputs and predictions
        real_preds = real_data[:, time_steps // 2:, :]
        real_preds_p = real_data_p[:, time_steps // 2:, :]
        real_inputs = real_data[:, :time_steps // 2, :]
        real_inputs_p = real_data_p[:, :time_steps // 2, :]

        # gradient() is called exactly once below, so a non-persistent tape
        # suffices; its resources are released as soon as it is used.
        with tf.GradientTape() as disc_tape:
            fake_data = generator.call(real_inputs)
            fake_data_p = generator.call(real_inputs_p)

            h_fake = discriminator_h.call(fake_data)

            m_real = discriminator_m.call(real_preds)
            m_fake = discriminator_m.call(fake_data)

            h_real_p = discriminator_h.call(real_preds_p)
            h_fake_p = discriminator_h.call(fake_data_p)

            m_real_p = discriminator_m.call(real_preds_p)

            loss1 = gan_utils.compute_mixed_sinkhorn_loss(
                real_preds, fake_data, m_real, m_fake, h_fake, scaling_coef,
                sinkhorn_eps, sinkhorn_l, real_preds_p, fake_data_p, m_real_p,
                h_real_p, h_fake_p)
            pm1 = gan_utils.scale_invariante_martingale_regularization(
                m_real, reg_penalty, scaling_coef)
            # Discriminators ascend the Sinkhorn loss, penalised by pm1.
            disc_loss = -loss1 + pm1
        # update discriminator parameters
        disc_vars = (list(discriminator_h.trainable_variables)
                     + list(discriminator_m.trainable_variables))
        disc_grads = disc_tape.gradient(disc_loss, disc_vars)
        # One apply_gradients call for both discriminators. The optimizer's
        # step counter then advances once per training step, so bias
        # correction and LR schedules stay in sync for h and m. Newer Keras
        # optimizers (TF >= 2.11) also reject separate calls that update
        # different variable subsets.
        dischm_optimiser.apply_gradients(zip(disc_grads, disc_vars))