Example #1
0
def build_graph(args, a_r, b_r, a2b_s, b2a_s):
    """Assemble CycleGAN generators, discriminators, losses and train ops.

    Every op is created under ``tf.device('/gpu:<args.gpu>')``.

    Args:
        args: options object; ``gpu``, ``lr`` and ``beta1`` are read.
        a_r: real domain-A batch.
        b_r: real domain-B batch.
        a2b_s: A->B samples that discriminator ``'b'`` scores as fakes.
        b2a_s: B->A samples that discriminator ``'a'`` scores as fakes.

    Returns:
        A 3-tuple ``(cvt, sum_, tr_op)`` where ``cvt`` is
        ``(a2b, b2a, a2b2a, b2a2b)``, ``sum_`` is
        ``(g_sum, d_sum_a, d_sum_b)`` and ``tr_op`` is
        ``(g_tr_op, d_tr_op_a, d_tr_op_b)``.
    """
    def adam_minimize(loss, var_list):
        # Each objective gets its own Adam instance with shared hyper-params.
        return tf.train.AdamOptimizer(args.lr, beta1=args.beta1).minimize(
            loss, var_list=var_list)

    def ls_loss(logits, target_like):
        # ops.l2_loss against a constant target built by ones_like/zeros_like.
        return ops.l2_loss(logits, target_like(logits))

    with tf.device('/gpu:{}'.format(args.gpu)):
        # Translations, then cycle reconstructions that reuse the generators'
        # variables (reuse=True on the second use of each scope).
        a2b = g_net(a_r, 'a2b')
        b2a = g_net(b_r, 'b2a')
        a2b2a = g_net(a2b, 'b2a', reuse=True)
        b2a2b = g_net(b2a, 'a2b', reuse=True)

        # Discriminator 'a': real A, freshly generated B->A, fed B->A samples.
        dis_a_real = d_net(a_r, 'a')
        dis_a_fake = d_net(b2a, 'a', reuse=True)
        dis_a_sample = d_net(b2a_s, 'a', reuse=True)

        # Discriminator 'b': the same arrangement for domain B.
        dis_b_real = d_net(b_r, 'b')
        dis_b_fake = d_net(a2b, 'b', reuse=True)
        dis_b_sample = d_net(a2b_s, 'b', reuse=True)

        # Generator objective: push both discriminators towards 1 on the
        # fresh translations, plus L1 cycle-consistency weighted by 10.
        g_loss_a2b = tf.identity(ls_loss(dis_b_fake, tf.ones_like),
                                 name='g_loss_a2b')
        g_loss_b2a = tf.identity(ls_loss(dis_a_fake, tf.ones_like),
                                 name='g_loss_b2a')
        cyc_loss_a = tf.identity(ops.l1_loss(a_r, a2b2a) * 10.0,
                                 name='cyc_loss_a')
        cyc_loss_b = tf.identity(ops.l1_loss(b_r, b2a2b) * 10.0,
                                 name='cyc_loss_b')
        g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

        # Discriminator objectives: mean of (real -> 1) and (sample -> 0).
        d_loss_a = tf.identity(
            (ls_loss(dis_a_real, tf.ones_like) +
             ls_loss(dis_a_sample, tf.zeros_like)) / 2.,
            name='d_loss_a')
        d_loss_b = tf.identity(
            (ls_loss(dis_b_real, tf.ones_like) +
             ls_loss(dis_b_sample, tf.zeros_like)) / 2.,
            name='d_loss_b')

        summaries = (
            ops.summary_tensors(
                [g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b]),
            ops.summary(d_loss_a),
            ops.summary(d_loss_b),
        )

        # Variables are split by substring of their scoped names.
        trainable = tf.trainable_variables()
        g_var = [v for v in trainable
                 if any(tag in v.name for tag in ('a2b_g', 'b2a_g'))]
        d_a_var = [v for v in trainable if 'a_d' in v.name]
        d_b_var = [v for v in trainable if 'b_d' in v.name]

        train_ops = (
            adam_minimize(g_loss, g_var),
            adam_minimize(d_loss_a, d_a_var),
            adam_minimize(d_loss_b, d_b_var),
        )
    return (a2b, b2a, a2b2a, b2a2b), summaries, train_ops
Example #2
0
def Main():
    """Build and train a BEGAN-style face generator.

    Trains generator ``G`` against discriminator ``D``. ``D``'s output is
    compared to its input with ``l1_loss``, and a balance term ``k_t`` is
    updated after every step. Data comes from
    ``../TrainingSet/facedata.mat`` (key ``"data"``), scaled from [0, 255]
    to [-1, 1].

    Every 100 batches the losses are printed and one generated image is
    written to ``./Results/<step>.jpg``. After every epoch a checkpoint is
    written to ``./save_para/model.ckpt``.
    """
    import os

    real_img = tf.placeholder("float", [BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3])
    z = tf.placeholder("float", [BATCH_SIZE, h])
    G = generator("generator")
    D = discriminator("discriminator")
    # Balance term k_t; re-assigned (clipped to [0, 1]) after every step.
    k_t = tf.get_variable("k", initializer=[0.])
    fake_img = G(z, IMG_SIZE, n)
    real_logits = D(real_img, IMG_SIZE, n, h)
    fake_logits = D(fake_img, IMG_SIZE, n, h)
    real_loss = l1_loss(real_img, real_logits)
    fake_loss = l1_loss(fake_img, fake_logits)
    D_loss = real_loss - k_t * fake_loss
    G_loss = fake_loss
    # Convergence measure that is only reported, never optimized.
    M_global = real_loss + tf.abs(GAMMA * real_loss - fake_loss)
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.inverse_time_decay(LEARNING_RATE, global_step,
                                                5000, 0.5)
    # Only the D optimizer advances global_step, so it counts iterations.
    Opt_D = tf.train.AdamOptimizer(learning_rate).minimize(
        D_loss, var_list=D.var(), global_step=global_step)
    Opt_G = tf.train.AdamOptimizer(learning_rate).minimize(G_loss,
                                                           var_list=G.var())
    # Running update_k also runs both optimizer steps.
    with tf.control_dependencies([Opt_D, Opt_G]):
        clip = tf.clip_by_value(k_t + LAMBDA * (GAMMA * real_loss - fake_loss),
                                0, 1)
        update_k = tf.assign(k_t, clip)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        facedata = sio.loadmat("../TrainingSet/facedata.mat")["data"]
        saver = tf.train.Saver()
        # To resume training: saver.restore(sess, "./save_para/model.ckpt")
        # Neither tf.train.Saver.save nor PIL's Image.save creates missing
        # parent directories, so make sure they exist up front.
        for directory in ("./Results", "./save_para"):
            if not os.path.isdir(directory):
                os.makedirs(directory)
        # All of these batches are full; the previous `- 1` silently skipped
        # the last complete batch every epoch.
        num_batches = facedata.shape[0] // BATCH_SIZE
        for epoch in range(200):
            for i in range(num_batches):
                start = i * BATCH_SIZE
                batch = facedata[start:start + BATCH_SIZE, :, :, :] / 127.5 - 1.0
                z0 = np.random.uniform(0, 1, [BATCH_SIZE, h])
                feed = {real_img: batch, z: z0}
                sess.run(update_k, feed_dict=feed)
                if i % 100 == 0:
                    dloss, gloss, Mglobal, fakeimg, step, lr = sess.run(
                        [D_loss, G_loss, M_global, fake_img, global_step,
                         learning_rate],
                        feed_dict=feed)
                    print(
                        "step: %d, d_loss: %f, g_loss: %f, M_global: %f, Learning_rate: %f"
                        % (step, dloss, gloss, Mglobal, lr))
                    # G's output range is not visible here. Clip before the
                    # uint8 cast so out-of-range pixels saturate instead of
                    # wrapping around.
                    pixels = np.clip(127.5 * (fakeimg[0, :, :, :] + 1), 0, 255)
                    Image.fromarray(np.uint8(pixels)).save(
                        "./Results/" + str(step) + ".jpg")
            saver.save(sess, "./save_para/model.ckpt")
Example #3
0
def _feedback_svalue_range(epoch):
    """Return the ``(low, high)`` range for redrawing ``svalue`` at ``epoch``.

    Epochs 0-49 use ``(0, 0.1)``, epochs 50-79 use ``(0.2, 0.5)`` and later
    epochs use ``(0.6, 0.9)``.
    """
    if epoch < 50:
        return 0, 0.1
    if epoch < 80:
        return 0.2, 0.5
    return 0.6, 0.9


def train():
    """Train ``Encoder_Decoder`` (x -> y) adversarially against three discriminators.

    Builds the TF1 graph and restores the latest checkpoint from
    ``args.snapshot_dir`` if one exists. It then runs ``args.epoch`` epochs
    of batch-size-1 updates over the training lists. Checkpoints, TensorBoard
    summaries and preview images are written every ``save_pred_every``,
    ``summary_pred_every`` and ``write_pred_every`` steps respectively.

    The discriminator objective is
    ``coefficient * dis_loss + (1 - coefficient) * svalue``. Both weights are
    fed on every step from ``args.coefficient`` / ``args.svalue``. Every
    ``args.human_number`` steps, ``svalue`` is redrawn from an epoch-dependent
    range and ``coefficient`` is set to 0.3. This feedback stays active while
    fewer than 50000 feedback events had occurred at the start of the epoch.
    """
    args = parse_args()
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    x_datalists, y_datalists = make_train_data_list(args.x_train_data_path, args.y_train_data_path)
    tf.set_random_seed(args.random_seed)
    image_shape = [1, args.image_size, args.image_size, 3]
    x_img = tf.placeholder(tf.float32, shape=image_shape, name='x_img')
    y_img = tf.placeholder(tf.float32, shape=image_shape, name='y_img')

    fake_y = Encoder_Decoder(image=x_img, reuse=False, name='Encoder_Decoder')
    dy_fake = discriminator(image=fake_y, reuse=False, name='discriminator_fake')
    dx_real = discriminator(image=x_img, reuse=False, name='discriminator_real1')
    dy_real = discriminator(image=y_img, reuse=False, name='discriminator_real2')

    encoder_loss = gan_loss(dy_fake, tf.ones_like(dy_fake)) + args.lamda * l1_loss(y_img, fake_y)
    # The dy_real term now builds its target from dy_real itself (it used
    # tf.ones_like(dx_real) before).
    raw_dis_loss = (gan_loss(dy_fake, tf.zeros_like(dy_fake))
                    + gan_loss(dx_real, tf.ones_like(dx_real))
                    + gan_loss(dy_real, tf.ones_like(dy_real)))

    # Feed the feedback weights at run time. Baking the Python floats in at
    # graph-construction time froze them at their initial values, so the
    # later updates to args.coefficient / args.svalue had no effect.
    coefficient = tf.placeholder(tf.float32, shape=[], name='coefficient')
    svalue = tf.placeholder(tf.float32, shape=[], name='svalue')
    dis_loss = coefficient * raw_dis_loss + (1 - coefficient) * svalue

    encoder_loss_sum = tf.summary.scalar("encoder_loss", encoder_loss)
    discriminator_sum = tf.summary.scalar("dis_loss", dis_loss)

    summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph())

    g_vars = [v for v in tf.trainable_variables() if 'Encoder_Decoder' in v.name]
    d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name]

    lr = tf.placeholder(tf.float32, None, name='learning_rate')
    d_optim = tf.train.AdamOptimizer(lr, beta1=args.beta1)
    e_optim = tf.train.AdamOptimizer(lr, beta1=args.beta1)
    d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars)
    d_train = d_optim.apply_gradients(d_grads_and_vars)
    e_grads_and_vars = e_optim.compute_gradients(encoder_loss, var_list=g_vars)
    e_train = e_optim.apply_gradients(e_grads_and_vars)

    train_op = tf.group(d_train, e_train)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)
        counter = 0
        number = 0
        ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")
        else:
            print('No Model')

        for epoch in range(args.epoch):
            # Feedback is enabled per epoch (checked once, at epoch start)
            # while fewer than 50000 feedback events have happened.
            feedback_range = _feedback_svalue_range(epoch) if number < 50000 else None
            # Constant LR until epoch_step, then linear decay towards 0.
            if epoch < args.epoch_step:
                lrate = args.base_lr
            else:
                lrate = args.base_lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)

            for step in range(len(x_datalists)):
                counter += 1
                x_image_resize, y_image_resize = TrainImageReader(x_datalists, y_datalists, step, args.image_size)
                batch_x_image = np.expand_dims(np.array(x_image_resize).astype(np.float32), axis=0)
                batch_y_image = np.expand_dims(np.array(y_image_resize).astype(np.float32), axis=0)
                feed_dict = {lr: lrate, x_img: batch_x_image, y_img: batch_y_image,
                             coefficient: args.coefficient, svalue: args.svalue}
                encoder_loss_value, dis_loss_value, _ = sess.run([encoder_loss, dis_loss, train_op], feed_dict=feed_dict)
                if counter % args.save_pred_every == 0:
                    save(saver, sess, args.snapshot_dir, counter)
                if counter % args.summary_pred_every == 0:
                    encoder_loss_sum_value, discriminator_sum_value = sess.run([encoder_loss_sum, discriminator_sum], feed_dict=feed_dict)
                    summary_writer.add_summary(encoder_loss_sum_value, counter)
                    summary_writer.add_summary(discriminator_sum_value, counter)
                if counter % args.write_pred_every == 0:
                    fake_y_value = sess.run(fake_y, feed_dict=feed_dict)
                    write_image = get_write_picture(x_image_resize, y_image_resize, fake_y_value)
                    write_image_name = args.out_dir + "/out" + str(counter) + ".png"
                    cv2.imwrite(write_image_name, write_image)
                if feedback_range is not None and counter % args.human_number == 0:
                    # NOTE(review): this loop redraws svalue `duration` times and
                    # keeps only the last draw. It does not hold the values for
                    # `duration` steps. Preserved as-is; confirm the intent.
                    for _ in range(args.duration):
                        args.svalue = random.uniform(*feedback_range)
                        args.coefficient = 0.3
                    number += 1
                print('epoch {:d} step {:d} \t encoder_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, encoder_loss_value, dis_loss_value))

        summary_writer.close()
Example #4
0
# Per-epoch PyTorch training loop. Each batch from train_loader gets one
# optimizer step on `criterion`; the loss value and an L1 metric are recorded.
# Every `eval_per_iter` steps the model is run over test_loader under
# torch.no_grad().
# NOTE(review): tqdm_iter, train_loader, test_loader, model, opt, criterion,
# l1_loss, global_step and eval_per_iter are not defined in this view; they
# are assumed to be set up earlier in the script.
for epoch in tqdm_iter:
    # NOTE(review): loss_li is never appended to in the visible code.
    loss_li = []
    mse_li = []
    diff_li = []
    for i, data in enumerate(train_loader, start=0):
        # Move every tensor in the batch to the CUDA device.
        inputs, labels = (d.to('cuda') for d in data)
        #soft_labels = soft_transform(labels, std=0.1)

        #optimize
        opt.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        opt.step()

        # mse_li records criterion's value. The name suggests criterion is MSE,
        # but that is not visible here; confirm.
        diff = l1_loss(outputs.detach(), labels)
        mse_li.append(loss.item())
        diff_li.append(diff.item())

        # NOTE(review): global_step is not incremented anywhere in the visible
        # code, so this test gives the same result on every batch; confirm it
        # is advanced elsewhere. The model is also never switched to eval mode
        # for this pass. Consider model.eval()/model.train() if it contains
        # dropout or batch-norm layers.
        if global_step % eval_per_iter == 0:
            mse_li_ = []
            diff_li_ = []
            for j, data_ in enumerate(test_loader, start=0):
                with torch.no_grad():
                    inputs_, labels_ = (d.to('cuda') for d in data_)
                    outputs_ = model(inputs_).detach()
                    mse_ = criterion(outputs_, labels_)
                    diff_ = l1_loss(outputs_, labels_)
                    mse_li_.append(mse_.item())
                    diff_li_.append(diff_.item())
            # NOTE(review): mse_li_ / diff_li_ are not used after this loop
            # in the visible code.
    # NOTE(review): from here to the end of this loop body, the code is a
    # TensorFlow CycleGAN graph construction unrelated to the PyTorch loop
    # above. It is indented inside `for epoch in tqdm_iter`, so it would run
    # once per epoch. On the first pass it reads b2a/a2b before their
    # assignments further down (unless they exist from an earlier scope), and
    # it builds the same graph twice. This looks like a paste artifact;
    # confirm, then move or remove it.
    b2a2b = models.generator(b2a, 'a2b', reuse=True)
    a2b2a = models.generator(a2b, 'b2a', reuse=True)

    a_dis = models.discriminator(a_real, 'a')
    b2a_dis = models.discriminator(b2a, 'a', reuse=True)
    b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
    b_dis = models.discriminator(b_real, 'b')
    a2b_dis = models.discriminator(a2b, 'b', reuse=True)
    a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

    # losses
    g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)),
                             name='g_loss_a2b')
    g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)),
                             name='g_loss_b2a')
    cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0,
                             name='cyc_loss_a')
    cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0,
                             name='cyc_loss_b')
    g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

    d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
    d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis,
                                    tf.zeros_like(b2a_sample_dis))
    d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0,
                           name='d_loss_a')
    d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
    d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis,
                                    tf.zeros_like(a2b_sample_dis))
    d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0,
                           name='d_loss_b')
    # NOTE(review): second copy of the generator/discriminator/loss
    # construction above; the tf.identity names will be uniquified by TF.
    a2b = models.generator(a_real, 'a2b')
    b2a = models.generator(b_real, 'b2a')
    b2a2b = models.generator(b2a, 'a2b', reuse=True)
    a2b2a = models.generator(a2b, 'b2a', reuse=True)

    a_dis = models.discriminator(a_real, 'a')
    b2a_dis = models.discriminator(b2a, 'a', reuse=True)
    b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
    b_dis = models.discriminator(b_real, 'b')
    a2b_dis = models.discriminator(a2b, 'b', reuse=True)
    a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

    # losses
    g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')
    g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
    cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0, name='cyc_loss_a')
    cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0, name='cyc_loss_b')
    g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

    d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
    d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
    d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0, name='d_loss_a')
    d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
    d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
    d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0, name='d_loss_b')

    # summaries
    g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
    d_summary_a = ops.summary(d_loss_a)
    d_summary_b = ops.summary(d_loss_b)
Example #7
0
def build_networks():
    """Build the CycleGAN graph, with optional double transform / double cycle.

    All ops are placed on ``/gpu:<args.gpu_id>``. The function creates four
    float32 placeholders of shape ``[None, crop_size, crop_size, 3]``: real A,
    real B, and fed A->B / B->A samples. It then builds generator and
    discriminator outputs, losses, summaries, and three Adam train ops
    (``args.lr``, beta1=0.5).

    ``args.transform_twice`` applies each generator twice for the forward
    translation and twice more for the cycle back. ``args.double_cycle``
    adds a 4-step cycle loss and extra discriminator terms on
    generator(sample).

    Returns:
        (g_train_op, d_a_train_op, d_b_train_op, g_summary, d_summary_a,
        d_summary_b, a2b, a2b2a, b2a, b2a2b, a_real, b_real, a2b_sample,
        b2a_sample, a2b1, b2a1)
    """
    with tf.device('/gpu:%d' % args.gpu_id):
        # Nodes
        a_real = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b_real = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        a2b_sample = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b2a_sample = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])

        # Single application of each generator (also returned as a2b1/b2a1).
        a2b1 = models.generator(a_real, 'a2b')
        b2a1 = models.generator(b_real, 'b2a')

        if args.transform_twice: #a-b-c
            a2b = models.generator(a2b1, 'a2b', reuse=True)
            b2a = models.generator(b2a1, 'b2a', reuse=True)
        else:
            a2b = a2b1
            b2a = b2a1

        b2a2b = models.generator(b2a, 'a2b', reuse=True)
        a2b2a = models.generator(a2b, 'b2a', reuse=True)
        
        if args.transform_twice: #a-b-c
            b2a2b = models.generator(b2a2b, 'a2b', reuse=True)
            a2b2a = models.generator(a2b2a, 'b2a', reuse=True)

        # Add extra loss term to enforce the discriminator's power to discern A samples from B samples

        a_dis = models.discriminator(a_real, 'a')
        a_from_b_dis = models.discriminator(b_real, 'a', reuse=True) #mod1

        b2a_dis = models.discriminator(b2a, 'a', reuse=True)
        b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
        b_dis = models.discriminator(b_real, 'b')
        b_from_a_dis = models.discriminator(a_real, 'b', reuse=True) #mod1

        a2b_dis = models.discriminator(a2b, 'b', reuse=True)
        a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

        # Stays the Python float 0.0 unless double_cycle adds tensor terms.
        double_cycle_loss = 0.0

        if args.double_cycle: #Now making these double-processed samples belong to the same domain as 1-processed. I.e. the domains are "reflexive".
            a2b_sample_dis2 = models.discriminator(models.generator(a2b_sample, 'a2b', reuse=True), 'b', reuse=True)
            b2a_sample_dis2 = models.discriminator(models.generator(b2a_sample, 'b2a', reuse=True), 'a', reuse=True)

            a2b2b = models.generator(a2b, 'a2b', reuse=True)
            a2b2b2a = models.generator(a2b2b, 'b2a', reuse=True)
            a2b2b2a2a = models.generator(a2b2b2a, 'b2a', reuse=True)
            b2a2a = models.generator(b2a, 'b2a', reuse=True)
            b2a2a2b = models.generator(b2a2a, 'a2b', reuse=True)
            b2a2a2b2b = models.generator(b2a2a2b, 'a2b', reuse=True)

            cyc_loss_a2 = tf.identity(ops.l1_loss(a_real, a2b2b2a2a) * 10.0, name='cyc_loss_a2')
            cyc_loss_b2 = tf.identity(ops.l1_loss(b_real, b2a2a2b2b) * 10.0, name='cyc_loss_b2')

            double_cycle_loss = cyc_loss_a2 + cyc_loss_b2

        # Losses
        g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')
        g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
        cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0, name='cyc_loss_a')
        cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0, name='cyc_loss_b')
        g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b + double_cycle_loss

        # Placeholders of 0.0 so the sums below work when double_cycle is off.
        d_loss_b2a_sample2 = d_loss_a2b_sample2 = 0.0

        d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
        d_loss_a_from_b_real = tf.identity(ops.l2_loss(a_from_b_dis, tf.zeros_like(a_from_b_dis)), name='d_loss_a_from_b') #mod1

        d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
        if args.double_cycle:
            # NOTE(review): the target is built from b2a_sample_dis, not
            # b2a_sample_dis2. This only matches if both outputs have the
            # same shape; probably meant tf.zeros_like(b2a_sample_dis2).
            d_loss_b2a_sample2 = ops.l2_loss(b2a_sample_dis2, tf.zeros_like(b2a_sample_dis))

        # NOTE(review): four terms are summed but divided by 3.0. With
        # double_cycle off, one term is 0.0 and this is a 3-term mean; with it
        # on, the loss is scaled by 4/3. Confirm the intended normalization.
        d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample + d_loss_b2a_sample2 + d_loss_a_from_b_real) / 3.0, name='d_loss_a')
        d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
        d_loss_b_from_a_real = tf.identity(ops.l2_loss(b_from_a_dis, tf.zeros_like(b_from_a_dis)), name='d_loss_b_from_a') #mod1

        d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
        if args.double_cycle:
            # NOTE(review): same zeros_like source mismatch as the A side.
            d_loss_a2b_sample2 = ops.l2_loss(a2b_sample_dis2, tf.zeros_like(a2b_sample_dis))  

        # NOTE(review): same 4-term / 3.0 normalization question as d_loss_a.
        d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample + d_loss_a2b_sample2 + d_loss_b_from_a_real) / 3.0, name='d_loss_b')

        # Summaries
        # NOTE(review): the double-cycle losses (cyc_loss_a2/b2) are not summarized.
        g_summary   = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
        d_summary_a = ops.summary_tensors([d_loss_a, d_loss_a_from_b_real])
        d_summary_b = ops.summary_tensors([d_loss_b, d_loss_b_from_a_real])

        # Optim
        # Variables are split by substring match on their names. This assumes
        # models.generator/discriminator scope them as '<tag>_generator' /
        # '<tag>_discriminator' (not visible here; confirm).
        t_var = tf.trainable_variables()
        d_a_var = [var for var in t_var if 'a_discriminator' in var.name]
        d_b_var = [var for var in t_var if 'b_discriminator' in var.name]
        g_var   = [var for var in t_var if 'a2b_generator'   in var.name or 'b2a_generator' in var.name]

        d_a_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(d_loss_a, var_list=d_a_var)
        d_b_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(d_loss_b, var_list=d_b_var)
        g_train_op = tf.train.AdamOptimizer(args.lr, beta1=0.5).minimize(g_loss, var_list=g_var)

        return g_train_op, d_a_train_op, d_b_train_op, g_summary, d_summary_a, d_summary_b, a2b, a2b2a, b2a, b2a2b, a_real, b_real, a2b_sample, b2a_sample, a2b1, b2a1
    # NOTE(review): everything below comes after the unconditional `return`
    # inside the `with` block above and is unreachable. It duplicates the
    # plain CycleGAN graph construction; looks like a paste artifact. Consider
    # deleting it.
    a2b = models.generator(a_real, 'a2b')
    b2a = models.generator(b_real, 'b2a')
    b2a2b = models.generator(b2a, 'a2b', reuse=True)
    a2b2a = models.generator(a2b, 'b2a', reuse=True)

    a_dis = models.discriminator(a_real, 'a')
    b2a_dis = models.discriminator(b2a, 'a', reuse=True)
    b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)
    b_dis = models.discriminator(b_real, 'b')
    a2b_dis = models.discriminator(a2b, 'b', reuse=True)
    a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)

    # losses
    g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')
    g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
    cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * 10.0, name='cyc_loss_a')
    cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * 10.0, name='cyc_loss_b')
    g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a + cyc_loss_b

    d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
    d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
    d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0, name='d_loss_a')
    d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
    d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
    d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0, name='d_loss_b')

    # summaries
    g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b])
    d_summary_a = ops.summary(d_loss_a)
    d_summary_b = ops.summary(d_loss_b)