コード例 #1
0
def train(train_datasetA, train_datasetB, epochs, lsgan=True, cyc_lambda=10):
    """Train the CycleGAN generators and discriminators for `epochs` steps.

    Args:
        train_datasetA: iterator yielding domain-A minibatches (default size 1).
        train_datasetB: iterator yielding domain-B minibatches.
        epochs: number of training iterations to run.
        lsgan: if True, use least-squares GAN losses.
        cyc_lambda: weight of the cycle-consistency loss term.

    NOTE(review): relies on module-level genA2B/genB2A/discA/discB models,
    their optimizers, and generate_images — confirm they are defined before
    calling.
    """
    for epoch in range(epochs):

        start = time.time()

        # One tape per network so each loss can be differentiated against
        # its own variables independently.
        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
                tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape:

            try:
                # Next training minibatches, default size 1
                trainA = next(train_datasetA)
                trainB = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error, run out of data")
                break

            genA2B_output = genA2B(trainA, training=True)
            genB2A_output = genB2A(trainB, training=True)

            discA_real_output = discA(trainA, training=True)
            discB_real_output = discB(trainB, training=True)

            discA_fake_output = discA(genB2A_output, training=True)
            discB_fake_output = discB(genA2B_output, training=True)

            # Full cycles: A -> B -> A and B -> A -> B.
            reconstructedA = genB2A(genA2B_output, training=True)
            reconstructedB = genA2B(genB2A_output, training=True)

            # Use history buffer of 50 for disc loss
            discA_loss = discriminator_loss(discA_real_output, discA_fake_output, lsgan=lsgan)
            discB_loss = discriminator_loss(discB_real_output, discB_fake_output, lsgan=lsgan)

            # Both generators share the same joint cycle-consistency term.
            genA2B_loss = generator_loss(discB_fake_output, lsgan=lsgan) + \
                cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                       cyc_lambda=cyc_lambda)
            genB2A_loss = generator_loss(discA_fake_output, lsgan=lsgan) + \
                cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                       cyc_lambda=cyc_lambda)

        genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
        genB2A_gradients = genB2A_tape.gradient(genB2A_loss, genB2A.trainable_variables)

        discA_gradients = discA_tape.gradient(discA_loss, discA.trainable_variables)
        discB_gradients = discB_tape.gradient(discB_loss, discB.trainable_variables)

        genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
        genB2A_optimizer.apply_gradients(zip(genB2A_gradients, genB2A.trainable_variables))

        discA_optimizer.apply_gradients(zip(discA_gradients, discA.trainable_variables))
        discB_optimizer.apply_gradients(zip(discB_gradients, discB.trainable_variables))

        # Sample images only every 40 epochs.
        if epoch % 40 == 0:
            generate_images(trainA, trainB, genB2A_output, genA2B_output, epoch)

        # Fix: report timing every epoch — previously this print was nested
        # inside the `if epoch % 40 == 0` block, so it only ran on the epochs
        # where sample images were generated.
        print('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
コード例 #2
0
ファイル: dcgan.py プロジェクト: krooner/CS492-MLCV-CW2
    def train_step(images):
        """Run one DCGAN training step on a batch of real images.

        Args:
            images: batch of real training images.

        Returns:
            Tuple of (generator loss, discriminator loss) as Python floats.

        NOTE(review): reads module/closure-level `generator`, `discriminator`,
        their optimizers, and `args.batsize` — confirm the flag spelling.
        """
        noise = tf.random.normal([args.batsize, noise_dim])

        # D and G learn separately: one tape each.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)

            real_output = discriminator(images, training=True)
            fake_output = discriminator(generated_images, training=True)

            gen_loss = generator_loss(fake_output)
            disc_loss = discriminator_loss(real_output, fake_output,
                                           args.alpha)

        # Fix: compute and apply gradients OUTSIDE the tape context so the
        # tapes do not record the gradient/optimizer ops themselves (wastes
        # memory and is unnecessary for first-order gradients).
        gradients_of_generator = gen_tape.gradient(
            gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator,
                discriminator.trainable_variables))

        return gen_loss.numpy(), disc_loss.numpy()
コード例 #3
0
ファイル: train.py プロジェクト: kodamanbou/Pix2Pix
def train_step(input_image, target, epoch):
    """Run one pix2pix training step and log losses to TensorBoard.

    Performs the generator and discriminator forward passes under separate
    gradient tapes, applies one optimizer update to each network, then writes
    scalar summaries for each loss component at `step=epoch`.
    """
    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        fake_image = generator(input_image, training=True)

        # Conditional discriminator: scores (input, target) vs (input, fake).
        score_real = discriminator([input_image, target], training=True)
        score_fake = discriminator([input_image, fake_image], training=True)

        total_g_loss, gan_g_loss, l1_g_loss = generator_loss(
            score_fake, fake_image, target)
        d_loss = discriminator_loss(score_real, score_fake)

    grads_g = tape_g.gradient(total_g_loss, generator.trainable_variables)
    grads_d = tape_d.gradient(d_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(grads_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(grads_d, discriminator.trainable_variables))

    # Scalar tags must stay stable so existing TensorBoard runs line up.
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', total_g_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gan_g_loss, step=epoch)
        tf.summary.scalar('gen_l1_loss', l1_g_loss, step=epoch)
        tf.summary.scalar('disc_loss', d_loss, step=epoch)
コード例 #4
0
    def train_step(self, input_image, target):
        """One pix2pix training step: forward passes, losses, loss logging
        via tf.print, and one optimizer update for each network."""
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake = self.generator(input_image, training=True)

            # Conditional discriminator scores for real and generated pairs.
            real_score = self.discriminator([input_image, target],
                                            training=True)
            fake_score = self.discriminator([input_image, fake],
                                            training=True)

            # Only the combined generator loss is needed for the update.
            total_gen_loss, _, _ = generator_loss(fake_score, fake, target)
            total_disc_loss = discriminator_loss(real_score, fake_score)

        gen_grads = g_tape.gradient(total_gen_loss,
                                    self.generator.trainable_variables)
        disc_grads = d_tape.gradient(total_disc_loss,
                                     self.discriminator.trainable_variables)

        # Log losses before applying the updates (graph-safe printing).
        tf.print(
            'XX:XX:XX     INFO              trainer > Generator Loss:     ',
            total_gen_loss)
        tf.print(
            'XX:XX:XX     INFO              trainer > Discriminator Loss: ',
            total_disc_loss)

        self.generator_optimizer.apply_gradients(
            zip(gen_grads, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(
            zip(disc_grads, self.discriminator.trainable_variables))
コード例 #5
0
    def train_step(images):
        """Single DCGAN update: one generator step and one discriminator
        step computed from the same batch of sampled noise."""
        noise = tf.random.normal([BATCH_SIZE, noise_dim])

        with tf.GradientTape() as tape_gen, tf.GradientTape() as tape_disc:
            fakes = generator(noise, training=True)

            logits_real = discriminator(images, training=True)
            logits_fake = discriminator(fakes, training=True)

            loss_gen = generator_loss(logits_fake)
            loss_disc = discriminator_loss(logits_real, logits_fake)

        grads_gen = tape_gen.gradient(loss_gen, generator.trainable_variables)
        grads_disc = tape_disc.gradient(loss_disc,
                                        discriminator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(grads_gen, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(grads_disc, discriminator.trainable_variables))
コード例 #6
0
    def train_step(input_image, target):
        """Perform one training step.

        Args:
            input_image: Input image.
            target: Output image (ground truth).

        Returns:
            gen_loss: Generator loss.
            disc_loss: Discriminator loss.
            gen_output: The generated image.
        """
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # Generator forward pass.
            gen_output = generator(input_image, training=True)

            # Discriminator scores for the real and the generated pair.
            real_logits = discriminator([input_image, target], training=True)
            fake_logits = discriminator([input_image, gen_output],
                                        training=True)

            # Losses for both networks.
            gen_loss = generator_loss(fake_logits, gen_output, target)
            disc_loss = discriminator_loss(real_logits, fake_logits)

        # Gradient descent on each network's own variables.
        grad_gen = g_tape.gradient(gen_loss, generator.trainable_variables)
        grad_disc = d_tape.gradient(disc_loss,
                                    discriminator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(grad_gen, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(grad_disc, discriminator.trainable_variables))

        return gen_loss, disc_loss, gen_output
コード例 #7
0
def train_step(input_image, target):
    """One pix2pix optimizer step for both generator and discriminator."""
    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        generated = generator(input_image, training=True)

        real_logits = discriminator([input_image, target], training=True)
        fake_logits = discriminator([input_image, generated], training=True)

        loss_g = generator_loss(fake_logits, generated, target)
        loss_d = discriminator_loss(real_logits, fake_logits)

    grads_g = tape_g.gradient(loss_g, generator.trainable_variables)
    grads_d = tape_d.gradient(loss_d, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(grads_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(grads_d, discriminator.trainable_variables))
コード例 #8
0
ファイル: train.py プロジェクト: zhuxiayin/PictureColoring
def train_step(input_data, target):
    """One conditional-GAN training step for the colorization models.

    Args:
        input_data: generator/discriminator conditioning input.
        target: ground-truth output image.

    Returns:
        tot_gen_loss: combined generator loss (adversarial + L1).
        gen_loss: adversarial component of the generator loss.
        gen_l1_loss: L1 component of the generator loss.
        tot_dis_loss: discriminator loss.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        # Fix: pass training=True so layers such as dropout/batch-norm run
        # in training mode during optimization (Keras models default to
        # inference behavior when the flag is omitted).
        gen_output = gen(input_data, training=True)

        dis_real_output = dis([input_data, target], training=True)
        dis_gene_output = dis([input_data, gen_output], training=True)

        tot_gen_loss, gen_loss, gen_l1_loss = model.generator_loss(
            dis_gene_output, gen_output, target)
        tot_dis_loss = model.discriminator_loss(dis_real_output,
                                                dis_gene_output)

    gen_gradients = gen_tape.gradient(tot_gen_loss, gen.trainable_variables)
    dis_gradients = dis_tape.gradient(tot_dis_loss, dis.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gen_gradients, gen.trainable_variables)
    )

    discriminator_optimizer.apply_gradients(
        zip(dis_gradients, dis.trainable_variables)
    )

    return tot_gen_loss, gen_loss, gen_l1_loss, tot_dis_loss
コード例 #9
0
def train_gan(D_net,
              G_net,
              D_optimizer,
              G_optimizer,
              discriminator_loss,
              generator_loss,
              noise_size=96,
              num_epochs=10):
    """Alternating GAN training loop (PyTorch).

    For each minibatch, first updates the discriminator on real data and a
    batch of generated data, then updates the generator on a fresh forward
    pass through the (just-updated) discriminator.

    Args:
        D_net / G_net: discriminator and generator networks (on GPU).
        D_optimizer / G_optimizer: their optimizers.
        discriminator_loss / generator_loss: loss callables.
        noise_size: dimensionality of the generator's noise input.
        num_epochs: number of passes over `train_data`.

    NOTE(review): iterates a module-level `train_data` loader — confirm it
    is defined before calling.
    """
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]

            # --- Discriminator update ---
            # Fix: move the real batch and noise seed to the GPU. The
            # generator pass below already used .cuda(), so the original
            # crashed (device mismatch) whenever the models lived on GPU.
            real_data = Variable(x).cuda()          # real data
            logits_real = D_net(real_data)          # discriminator scores

            # Uniform noise in [-1, 1]
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5
            g_fake_seed = Variable(sample_noise).cuda()
            fake_images = G_net(g_fake_seed)        # generated (fake) data
            logits_fake = D_net(fake_images)        # discriminator scores

            d_total_error = discriminator_loss(logits_real,
                                               logits_fake)  # discriminator loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # optimize the discriminator

            # --- Generator update ---
            g_fake_seed = Variable(sample_noise).cuda()
            fake_images = G_net(g_fake_seed)        # generated (fake) data

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # generator loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # optimize the generator

            # Fix: the counter was declared but never advanced.
            iter_count += 1
コード例 #10
0
def main():
    """Train a Blending GAN (WGAN-style critic with weight clipping).

    Parses hyperparameters, builds train/val input pipelines, constructs the
    encoder-decoder generator and DCGAN discriminator (plus frozen "val"
    copies sharing the same variable scopes), wires losses, optimizers and
    TensorBoard summaries, then runs the alternating WGAN training loop with
    periodic validation and checkpointing.
    """
    parser = argparse.ArgumentParser(description='Train Blending GAN')
    parser.add_argument('--nef',
                        type=int,
                        default=64,
                        help='number of base filters in encoder')
    parser.add_argument('--ngf',
                        type=int,
                        default=64,
                        help='number of base filters in decoder')
    parser.add_argument('--nc',
                        type=int,
                        default=3,
                        help='number of output channels in decoder')
    parser.add_argument('--nBottleneck',
                        type=int,
                        default=4000,
                        help='number of output channels in encoder')
    parser.add_argument('--ndf',
                        type=int,
                        default=64,
                        help='number of base filters in D')

    parser.add_argument('--lr_d',
                        type=float,
                        default=0.0002,
                        help='Learning rate for Critic, default=0.0002')
    parser.add_argument('--lr_g',
                        type=float,
                        default=0.002,
                        help='Learning rate for Generator, default=0.002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='Beta for Adam, default=0.5')
    # Fix: help text said 0.999 while the actual default is 0.99.
    parser.add_argument('--l2_weight',
                        type=float,
                        default=0.99,
                        help='Weight for l2 loss, default=0.99')
    # Fix: was default=float("58000") with no type=, so any value supplied on
    # the command line stayed a string and `cycle <= args.train_steps` would
    # raise TypeError on Python 3.
    parser.add_argument('--train_steps',
                        type=int,
                        default=58000,
                        help='Max amount of training cycles')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='Input batch size')

    parser.add_argument('--data_root',
                        default='DataBase/TransientAttributes/cropped_images',
                        help='Path to dataset')
    parser.add_argument('--train_data_root',
                        default='DataBase/TransientAttributes/train.tfrecords',
                        help='Path to train dataset')
    parser.add_argument('--val_data_root',
                        default='DataBase/TransientAttributes/val.tfrecords',
                        help='Path to val dataset')
    parser.add_argument(
        '--image_size',
        type=int,
        default=64,
        help='The height / width of the network\'s input image')

    parser.add_argument(
        '--d_iters',
        type=int,
        default=5,
        help='# of discriminator iters per each generator iter')
    parser.add_argument('--clamp_lower',
                        type=float,
                        default=-0.01,
                        help='Lower bound for weight clipping')
    parser.add_argument('--clamp_upper',
                        type=float,
                        default=0.01,
                        help='Upper bound for weight clipping')

    parser.add_argument('--experiment',
                        default='blending_gan',
                        help='Where to store samples and models')
    parser.add_argument('--save_folder',
                        default='GP-GAN_training',
                        help='location to save')
    parser.add_argument('--tboard_save_dir',
                        default='tensorboard',
                        help='location to save tboard records')

    parser.add_argument('--val_freq',
                        type=int,
                        default=500,
                        help='frequency of validation')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=500,
                        help='Interval of snapshot (steps)')

    parser.add_argument('--weights_path',
                        type=str,
                        default=None,
                        help='path to checkpoint')

    args = parser.parse_args()

    print('Input arguments:')
    for key, value in vars(args).items():
        print('\t{}: {}'.format(key, value))
    print('')

    # Set up generator & discriminator. The *_val copies reuse the same
    # variable scope names with is_training=False so they share weights.
    print('Create & Init models ...')
    print('\tInit Generator network ...')
    generator = EncoderDecoder(encoder_filters=args.nef,
                               encoded_dims=args.nBottleneck,
                               output_channels=args.nc,
                               decoder_filters=args.ngf,
                               is_training=True,
                               image_size=args.image_size,
                               skip=False,
                               scope_name='generator')  #, conv_init=init_conv,

    generator_val = EncoderDecoder(encoder_filters=args.nef,
                                   encoded_dims=args.nBottleneck,
                                   output_channels=args.nc,
                                   decoder_filters=args.ngf,
                                   is_training=False,
                                   image_size=args.image_size,
                                   skip=False,
                                   scope_name='generator')

    print('\tInit Discriminator network ...')
    discriminator = DCGAN_D(image_size=args.image_size,
                            encoded_dims=1,
                            filters=args.ndf,
                            is_training=True,
                            scope_name='discriminator'
                            )  #, conv_init=init_conv, bn_init=init_bn)  # D

    discriminator_val = DCGAN_D(image_size=args.image_size,
                                encoded_dims=1,
                                filters=args.ndf,
                                is_training=False,
                                scope_name='discriminator')

    # Set up training graph
    with tf.device('/gpu:0'):

        train_dataset = DataFeeder(tfrecords_path=args.train_data_root,
                                   dataset_flag='train')
        composed_image, real_image = train_dataset.inputs(
            batch_size=args.batch_size, name='train_dataset')
        shape = composed_image.get_shape().as_list()
        composed_image.set_shape(
            [shape[0], args.image_size, args.image_size, shape[3]])
        real_image.set_shape(
            [shape[0], args.image_size, args.image_size, shape[3]])

        validation_dataset = DataFeeder(tfrecords_path=args.val_data_root,
                                        dataset_flag='val')
        composed_image_val, real_image_val = validation_dataset.inputs(
            batch_size=args.batch_size, name='val_dataset')
        # NOTE(review): reuses `shape` from the train batch; both pipelines
        # use the same batch_size so dims agree — confirm channel counts match.
        composed_image_val.set_shape(
            [shape[0], args.image_size, args.image_size, shape[3]])
        real_image_val.set_shape(
            [shape[0], args.image_size, args.image_size, shape[3]])

        # Compute losses:

        # Train tensors
        fake = generator(composed_image)
        prob_disc_real = discriminator.encode(real_image)
        prob_disc_fake = discriminator.encode(fake)

        # Validation tensors.
        # Fix: these were wired to the *training* tensors (composed_image,
        # real_image, fake), so "validation" evaluated the training batch.
        fake_val = generator_val(composed_image_val)
        prob_disc_real_val = discriminator_val.encode(real_image_val)
        prob_disc_fake_val = discriminator_val.encode(fake_val)

        # Calculate losses
        gen_loss, l2_comp, disc_comp, fake_image_train = l2_generator_loss(
            fake=fake,
            target=real_image,
            prob_disc_fake=prob_disc_fake,
            l2_weight=args.l2_weight)

        disc_loss = discriminator_loss(prob_disc_real=prob_disc_real,
                                       prob_disc_fake=prob_disc_fake)

        # Fix: the validation generator loss compared against the training
        # target; use the validation target.
        gen_loss_val, _, _, fake_image_val = l2_generator_loss(
            fake=fake_val,
            target=real_image_val,
            prob_disc_fake=prob_disc_fake_val,
            l2_weight=args.l2_weight)

        disc_loss_val = discriminator_loss(prob_disc_real=prob_disc_real_val,
                                           prob_disc_fake=prob_disc_fake_val)

        # Set optimizers
        # NOTE(review): both optimizers increment the same global_step, so it
        # advances twice per G+D update pair — confirm this is intended.
        global_step = tf.Variable(0, name='global_step', trainable=False)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(update_ops):

            discriminator_variables = [
                v for v in tf.trainable_variables()
                if v.name.startswith("discriminator")
            ]
            generator_variables = [
                v for v in tf.trainable_variables()
                if v.name.startswith("generator")
            ]

            optimizer_gen = tf.train.AdamOptimizer(
                learning_rate=args.lr_g,
                beta1=args.beta1).minimize(loss=gen_loss,
                                           global_step=global_step,
                                           var_list=generator_variables)

            optimizer_disc = tf.train.AdamOptimizer(
                learning_rate=args.lr_d,
                beta1=args.beta1).minimize(loss=disc_loss,
                                           global_step=global_step,
                                           var_list=discriminator_variables)

            # WGAN weight clipping to enforce the Lipschitz constraint.
            with tf.name_scope("clip_weights"):
                clip_discriminator_var_op = [
                    var.assign(
                        tf.clip_by_value(var, args.clamp_lower,
                                         args.clamp_upper))
                    for var in discriminator_variables
                ]

    # Set summaries for Tensorboard

    model_save_dir = os.path.join(args.save_folder, args.experiment)

    tboard_save_dir = os.path.join(model_save_dir, args.tboard_save_dir)
    os.makedirs(tboard_save_dir, exist_ok=True)
    sum_gen_train = tf.summary.scalar(name='train_gen_loss', tensor=gen_loss)
    sum_gen_disc_comp = tf.summary.scalar(name='train_gen_disc_component',
                                          tensor=disc_comp)
    sum_gen_l2_comp = tf.summary.scalar(name='train_gen_l2_component',
                                        tensor=l2_comp)

    # NOTE(review): collections='' keeps this summary out of every collection
    # (it is fetched explicitly below) — confirm that was intentional.
    sum_gen_val = tf.summary.scalar(name='val_gen_loss',
                                    tensor=gen_loss_val,
                                    collections='')
    sum_disc_train = tf.summary.scalar(name='train_disc_loss',
                                       tensor=disc_loss)
    sum_disc_val = tf.summary.scalar(name='val_disc_loss',
                                     tensor=disc_loss_val)
    sum_fake_image_train = tf.summary.image(name='train_image_generated',
                                            tensor=fake_image_train)
    sum_fake_image_val = tf.summary.image(name='val_image_generated',
                                          tensor=fake_image_val)
    sum_disc_real = tf.summary.scalar(name='train_disc_value_real',
                                      tensor=tf.reduce_mean(prob_disc_real))
    sum_disc_fake = tf.summary.scalar(name='train_disc_value_fake',
                                      tensor=tf.reduce_mean(prob_disc_fake))

    sum_composed = tf.summary.image(name='composed', tensor=composed_image)
    sum_real = tf.summary.image(name='real', tensor=real_image)

    train_merge = tf.summary.merge([
        sum_gen_train, sum_fake_image_train, sum_disc_train, sum_composed,
        sum_real, sum_gen_disc_comp, sum_gen_l2_comp, sum_disc_real,
        sum_disc_fake
    ])

    # Set saver configuration

    loader = tf.train.Saver()
    saver = tf.train.Saver()
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'GP-GAN_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=sess_config)

    # Write graph to tensorboard
    summary_writer = tf.summary.FileWriter(tboard_save_dir)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters

    with sess.as_default():
        step = 0
        cycle = 0

        if args.weights_path is None:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            print('Restore model from {:s}'.format(args.weights_path))
            loader.restore(sess=sess, save_path=args.weights_path)

            # Resume step/cycle counters encoded in the checkpoint suffix.
            step_cycle = args.weights_path.split('ckpt-')[-1]
            step, cycle = decode_step_cycle(step_cycle)

        gen_train_loss = '?'
        while cycle <= args.train_steps:

            # (1) Update discriminator network.
            # Train the critic more often at the start and every 500 cycles
            # (standard WGAN schedule), otherwise d_iters times.
            if cycle < 25 or cycle % 500 == 0:
                Diters = 100
            else:
                Diters = args.d_iters

            for _ in range(Diters):
                # enforce Lipschitz constraint
                sess.run(clip_discriminator_var_op)

                _, disc_train_loss = sess.run([optimizer_disc, disc_loss])
                print('Step: ' + str(step) + ' Cycle: ' + str(cycle) +
                      ' Train discriminator loss: ' + str(disc_train_loss) +
                      ' Train generator loss: ' + str(gen_train_loss))

                step += 1

            # (2) Update generator network

            _, gen_train_loss, train_merge_value = sess.run(
                [optimizer_gen, gen_loss, train_merge])
            summary_writer.add_summary(summary=train_merge_value,
                                       global_step=cycle)

            if cycle != 0 and cycle % args.val_freq == 0:
                # Fix: validation must only *evaluate*. The original also ran
                # optimizer_disc/optimizer_gen here (training on validation
                # data) and assigned the generator loss to disc_val_loss and
                # vice versa.
                (gen_val_loss, disc_val_loss, gen_val_value, disc_val_value,
                 fake_image_val_value) = sess.run([
                     gen_loss_val, disc_loss_val, sum_gen_val, sum_disc_val,
                     sum_fake_image_val
                 ])
                print('Step: ' + str(step) + ' Cycle: ' + str(cycle) +
                      ' Val discriminator loss: ' + str(disc_val_loss) +
                      ' Val generator loss: ' + str(gen_val_loss))
                summary_writer.add_summary(summary=gen_val_value,
                                           global_step=cycle)
                summary_writer.add_summary(summary=disc_val_value,
                                           global_step=cycle)
                summary_writer.add_summary(summary=fake_image_val_value,
                                           global_step=cycle)

            if cycle != 0 and cycle % args.snapshot_interval == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=encode_step_cycle(step, cycle))
            cycle += 1
コード例 #11
0
def train_step(real_x, real_y, G_YtoX, G_XtoY, D_X, D_Y, G_YtoX_optimizer,
               G_XtoY_optimizer, D_X_optimizer, D_Y_optimizer, opt):
    """One CycleGAN training step over one batch from each domain.

    Computes adversarial, cycle-consistency (optional via
    opt["use_cycle_consistency_loss"]) and identity losses, then applies one
    optimizer update to both generators and — when the discriminator loss
    function requests it via its `update_*` flag — to each discriminator.

    Returns:
        (total_disc_loss, total_gen_loss) averaged over both directions,
        for reporting.
    """
    # persistent is set to True because the tape is used more than once to calculate the gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G_XtoY translates X -> Y
        # Generator G_YtoX translates Y -> X.

        fake_y = G_XtoY(real_x, training=True)
        cycled_x = G_YtoX(fake_y, training=True)

        fake_x = G_YtoX(real_y, training=True)
        cycled_y = G_XtoY(fake_x, training=True)

        # same_x and same_y are used for identity loss: each generator fed an
        # image already in its *output* domain should return it unchanged.
        # Fix: same_x must come from G_YtoX — the original used G_XtoY, which
        # made the G_YtoX identity term penalize the wrong generator.
        same_x = G_YtoX(real_x, training=True)
        same_y = G_XtoY(real_y, training=True)

        disc_real_x = D_X(real_x, training=True)
        disc_real_y = D_Y(real_y, training=True)

        disc_fake_x = D_X(fake_x, training=True)
        disc_fake_y = D_Y(fake_y, training=True)

        # calculate the loss
        G_XtoY_loss = model.generator_loss(disc_fake_y)
        G_YtoX_loss = model.generator_loss(disc_fake_x)

        if opt["use_cycle_consistency_loss"]:
            total_cycle_loss = model.calc_cycle_loss(
                real_x, cycled_x) + model.calc_cycle_loss(real_y, cycled_y)
        else:
            total_cycle_loss = 0

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_G_XtoY_loss = G_XtoY_loss + total_cycle_loss + model.identity_loss(
            real_y, same_y)
        total_G_YtoX_loss = G_YtoX_loss + total_cycle_loss + model.identity_loss(
            real_x, same_x)

        disc_x_loss, update_D_X = model.discriminator_loss(
            disc_real_x, disc_fake_x)
        disc_y_loss, update_D_Y = model.discriminator_loss(
            disc_real_y, disc_fake_y)

        # total loss to be shown
        total_disc_loss = (disc_x_loss + disc_y_loss) / 2
        total_gen_loss = (total_G_XtoY_loss + total_G_YtoX_loss) / 2

    # Calculate the gradients for generator and discriminator

    G_XtoY_gradients = tape.gradient(total_G_XtoY_loss,
                                     G_XtoY.trainable_variables)
    G_YtoX_gradients = tape.gradient(total_G_YtoX_loss,
                                     G_YtoX.trainable_variables)
    if update_D_X:
        D_X_gradients = tape.gradient(disc_x_loss, D_X.trainable_variables)
    if update_D_Y:
        D_Y_gradients = tape.gradient(disc_y_loss, D_Y.trainable_variables)

    # Release the persistent tape's resources now that all gradients are
    # computed (persistent tapes are not freed automatically).
    del tape

    # Apply the gradients to the optimizer

    G_XtoY_optimizer.apply_gradients(
        zip(G_XtoY_gradients, G_XtoY.trainable_variables))

    G_YtoX_optimizer.apply_gradients(
        zip(G_YtoX_gradients, G_YtoX.trainable_variables))
    if update_D_X:
        D_X_optimizer.apply_gradients(
            zip(D_X_gradients, D_X.trainable_variables))
    if update_D_Y:
        D_Y_optimizer.apply_gradients(
            zip(D_Y_gradients, D_Y.trainable_variables))

    return total_disc_loss, total_gen_loss