Ejemplo n.º 1
0
def train(dataset, gpu_id):
    """Run the fg/bg GAN generator over the test action list and save results.

    Despite its name this function does no training. It rebuilds the
    networks, loads GAN weights from
    ``../results/networks/fgbg_gan/2000.h5``, draws ``len(test)`` examples
    from the generator built on ``test_vids.txt`` and writes ground truth,
    predictions and the ``fg_mask_tgt`` layer output to
    ``results/action/1_gan.mat``.

    Args:
        dataset: Unused in this function; kept so existing callers still work.
        gpu_id: Index of the GPU the networks are placed on
            (``'/gpu:<gpu_id>'``).
    """
    params = param.getGeneralParams()
    gpu = '/gpu:' + str(gpu_id)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    with tf.device(gpu):
        vgg_model = myVGG.vgg_norm()
        networks.make_trainable(vgg_model, False)
        response_weights = sio.loadmat('mean_response.mat')
        fgbg = networks.network_fgbg(params, vgg_model, response_weights)
        disc = networks.discriminator(params)
        gan = networks.gan(fgbg, disc, params, vgg_model, response_weights,
                           0.01, 1e-4)
        # Weights are loaded through the combined model; presumably this also
        # populates `fgbg`, which was passed into it -- TODO confirm.
        gan.load_weights('../results/networks/fgbg_gan/2000.h5')

        # Export the generator's first output plus the target foreground
        # mask. Other layers previously exported for debugging: 'mask_src',
        # 'fg_stack', 'bg_src', 'bg_tgt', 'fg_tgt'.
        outputs = [fgbg.outputs[0], fgbg.get_layer('fg_mask_tgt').output]
        model = Model(fgbg.inputs, outputs)

    test = datareader.makeActionExampleList('test_vids.txt', 1)
    feed = datageneration.warpExampleGenerator(test,
                                               params,
                                               do_augment=False,
                                               return_pose_vectors=True)

    n_frames = len(test)

    # Frames are stacked along the last axis.
    true_action = np.zeros((256, 256, 3, n_frames))
    pred_action = np.zeros((256, 256, 3, n_frames))
    mask = np.zeros((256, 256, 1, n_frames))

    # range/print() are valid on both Python 2 and Python 3.
    for i in range(n_frames):
        print(i)
        X, Y = next(feed)
        # The last two entries of X are not network inputs (presumably the
        # pose vectors requested via return_pose_vectors=True).
        pred = model.predict(X[:-2])
        true_action[:, :, :, i] = convert(np.reshape(Y, (256, 256, 3)))
        pred_action[:, :, :, i] = convert(np.reshape(pred[0], (256, 256, 3)))
        mask[:, :, :, i] = pred[1]

    sio.savemat('results/action/1_gan.mat', {
        'true': true_action,
        'pred': pred_action,
        'mask': mask
    })
Ejemplo n.º 2
0
def train(model_name, gpu_id):
    """Adversarially fine-tune the posewarp generator.

    The generator is initialised from ``../models/vgg_100000.h5``. For 10000
    iterations the discriminator is trained on a batch of real targets
    (label 1) concatenated with generated images (label 0). From step 5
    onwards the combined GAN is also trained on a fresh batch, using a VGG
    loss plus binary cross-entropy weighted ``[1.0, 0.1]``. The GAN is saved
    to ``<model_save_dir>/<model_name>/<step>.h5`` every
    ``params['model_save_interval']`` steps.

    Args:
        model_name: Sub-directory name under ``params['model_save_dir']``.
        gpu_id: Value exported as ``CUDA_VISIBLE_DEVICES``.
    """
    params = param.get_general_params()
    network_dir = params['model_save_dir'] + '/' + model_name

    # makedirs also creates a missing parent, where os.mkdir would raise.
    if not os.path.isdir(network_dir):
        os.makedirs(network_dir)

    train_feed = data_generation.create_feed(params, params['data_dir'], 'train')

    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # NOTE(review): this config is built but never passed to a session in
    # this function, so these options have no effect here -- confirm whether
    # a set_session(tf.Session(config=config)) call is missing.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    gan_lr = 1e-4
    disc_lr = 1e-4
    disc_loss = 0.1

    generator = networks.network_posewarp(params)
    generator.load_weights('../models/vgg_100000.h5')

    discriminator = networks.discriminator(params)
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=disc_lr))

    vgg_model = truncated_vgg.vgg_norm()
    networks.make_trainable(vgg_model, False)
    response_weights = sio.loadmat('../data/vgg_activation_distribution_train.mat')

    gan = networks.gan(generator, discriminator, params)
    gan.compile(optimizer=Adam(lr=gan_lr),
                loss=[networks.vgg_loss(vgg_model, response_weights, 12), 'binary_crossentropy'],
                loss_weights=[1.0, disc_loss])

    n_iters = 10000
    batch_size = params['batch_size']

    # Discriminator labels: first half of the batch is real (1), second half
    # is generated (0).
    disc_labels = np.zeros([2 * batch_size])
    disc_labels[0:batch_size] = 1
    # GAN labels: the generator is trained to have its output classified as real.
    gan_labels = np.ones([batch_size])

    for step in range(n_iters):

        x, y = next(train_feed)

        gen = generator.predict(x)

        # Train discriminator on real targets followed by generated images,
        # both paired with the same source/target poses.
        x_tgt_img_disc = np.concatenate((y, gen))
        x_src_pose_disc = np.concatenate((x[1], x[1]))
        x_tgt_pose_disc = np.concatenate((x[2], x[2]))

        inputs = [x_tgt_img_disc, x_src_pose_disc, x_tgt_pose_disc]
        d_loss = discriminator.train_on_batch(inputs, disc_labels)

        # Train the discriminator a couple of iterations before starting the gan
        if step < 5:
            util.printProgress(step, 0, [0, d_loss])
            continue

        # Train the GAN on a fresh batch.
        x, y = next(train_feed)
        g_loss = gan.train_on_batch(x, [y, gan_labels])
        # g_loss[1] is presumably the first output's (VGG) loss term.
        util.printProgress(step, 0, [g_loss[1], d_loss])

        if step % params['model_save_interval'] == 0 and step > 0:
            gan.save(network_dir + '/' + str(step) + '.h5')
Ejemplo n.º 3
0
def train(model_name, gpu_id):
    """Adversarially fine-tune the fg/bg generator on the training videos.

    The generator is initialised from
    ``../results/networks/fgbg_vgg/100000.h5``. For 10001 steps the
    discriminator is trained on real targets (label 1) concatenated with
    generated images (label 0). From step 5 onwards the combined GAN is also
    trained on a fresh batch. The GAN is saved to
    ``<project_dir>/results/networks/<model_name>/<step>.h5`` every
    ``params['model_save_interval']`` steps.

    Args:
        model_name: Sub-directory name for the saved networks.
        gpu_id: Index of the GPU used for model construction and prediction
            (``'/gpu:<gpu_id>'``).
    """
    params = param.getGeneralParams()
    gpu = '/gpu:' + str(gpu_id)

    network_dir = params['project_dir'] + '/results/networks/' + model_name

    # makedirs also creates a missing parent, where os.mkdir would raise.
    if not os.path.isdir(network_dir):
        os.makedirs(network_dir)

    train_feed = datageneration.createFeed(params, "train_vids.txt")
    # Created but not consumed below: the periodic evaluation that used a
    # test feed was disabled in this script.
    test_feed = datageneration.createFeed(params, "test_vids.txt")

    batch_size = params['batch_size']

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:

        sess.run(tf.global_variables_initializer())
        # NOTE(review): the coordinator is never asked to stop and the
        # threads are never joined -- confirm whether any queue runners are
        # actually registered.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        gan_lr = 1e-4
        disc_lr = 1e-4
        disc_loss = 0.1

        with tf.device(gpu):
            vgg_model = myVGG.vgg_norm()
            networks.make_trainable(vgg_model, False)
            response_weights = sio.loadmat('mean_response.mat')
            generator = networks.network_fgbg(params)
            generator.load_weights('../results/networks/fgbg_vgg/100000.h5')

            discriminator = networks.discriminator(params)
            discriminator.compile(loss='binary_crossentropy',
                                  optimizer=Adam(lr=disc_lr))
            gan = networks.gan(generator, discriminator, params, vgg_model,
                               response_weights, disc_loss, gan_lr)
            gan.compile(optimizer=Adam(lr=gan_lr),
                        loss=[
                            networks.vggLoss(vgg_model, response_weights),
                            'binary_crossentropy'
                        ],
                        loss_weights=[1.0, disc_loss])

        # Discriminator labels: first half real (1), second half generated (0).
        disc_labels = np.zeros([2 * batch_size])
        disc_labels[0:batch_size] = 1
        # GAN labels: the generator is trained to be classified as real.
        gan_labels = np.ones([batch_size])

        # range() is valid on both Python 2 and Python 3.
        for step in range(10001):

            X, Y = next(train_feed)

            with tf.device(gpu):
                gen = generator.predict(X)

            # Train discriminator on real targets followed by generated
            # images, both paired with the same source/target poses.
            X_tgt_img_disc = np.concatenate((Y, gen))
            X_src_pose_disc = np.concatenate((X[1], X[1]))
            X_tgt_pose_disc = np.concatenate((X[2], X[2]))

            inputs = [X_tgt_img_disc, X_src_pose_disc, X_tgt_pose_disc]
            d_loss = discriminator.train_on_batch(inputs, disc_labels)

            # Train the discriminator a couple of iterations before starting the gan
            if step < 5:
                util.printProgress(step, 0, [0, d_loss])
                continue

            # Train the GAN on a fresh batch.
            X, Y = next(train_feed)
            g_loss = gan.train_on_batch(X, [Y, gan_labels])
            util.printProgress(step, 0, [g_loss[1], d_loss])

            if step % params['model_save_interval'] == 0 and step > 0:
                gan.save(network_dir + '/' + str(step) + '.h5')
Ejemplo n.º 4
0
def train(dataset, gpu_id):
    """Evaluate the GAN-trained generator on 25 test batches and dump outputs.

    Despite its name this function does no training. Test example lists for
    the 'weightlifting', 'golfswinghd', 'workout', 'tennis' and 'test-aux'
    datasets are concatenated, and GAN weights are loaded from
    ``../results/networks/gan/10000.h5``. For each of 25 batches the
    generator's loss and prediction are written to
    ``results/outputs/<j>.mat``.

    Args:
        dataset: Unused in this function; kept so existing callers still work.
        gpu_id: Index of the GPU the networks are placed on
            (``'/gpu:<gpu_id>'``).
    """
    params = param.getGeneralParams()
    gpu = '/gpu:' + str(gpu_id)

    lift_params = param.getDatasetParams('weightlifting')
    golf_params = param.getDatasetParams('golfswinghd')
    workout_params = param.getDatasetParams('workout')
    tennis_params = param.getDatasetParams('tennis')
    aux_params = param.getDatasetParams('test-aux')

    # Only the second returned list (bound here as the *_test split) is used.
    _, lift_test = datareader.makeWarpExampleList(lift_params, 0, 2000, 2, 1)
    _, golf_test = datareader.makeWarpExampleList(golf_params, 0, 5000, 2, 2)
    _, workout_test = datareader.makeWarpExampleList(workout_params, 0, 2000,
                                                     2, 3)
    _, tennis_test = datareader.makeWarpExampleList(tennis_params, 0, 2000, 2,
                                                    4)
    _, aux_test = datareader.makeWarpExampleList(aux_params, 0, 2000, 2, 5)

    test = lift_test + golf_test + workout_test + tennis_test + aux_test
    feed = datageneration.warpExampleGenerator(test,
                                               params,
                                               do_augment=False,
                                               draw_skeleton=False,
                                               skel_color=(0, 0, 255),
                                               return_pose_vectors=True)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:

        sess.run(tf.global_variables_initializer())
        # NOTE(review): the coordinator is never asked to stop and the
        # threads are never joined -- confirm whether any queue runners are
        # actually registered.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        with tf.device(gpu):
            vgg_model = myVGG.vgg_norm()
            networks.make_trainable(vgg_model, False)
            response_weights = sio.loadmat('mean_response.mat')

            gen = networks.network_fgbg(params,
                                        vgg_model,
                                        response_weights,
                                        True,
                                        loss='vgg')
            disc = networks.discriminator(params)
            gan = networks.gan(gen, disc, params, vgg_model, response_weights,
                               0.01, 1e-4)
            # Weights are loaded through the combined model; presumably this
            # also populates `gen`, which was passed into it -- TODO confirm.
            gan.load_weights('../results/networks/gan/10000.h5')

        np.random.seed(17)
        n_batches = 25
        # range/print() are valid on both Python 2 and Python 3.
        for j in range(n_batches):
            print(j)
            X, Y = next(feed)
            # The last two entries of X are saved below as the source and
            # target poses and are not fed to the network.
            net_inputs = X[0:-2]
            loss = gen.evaluate(net_inputs, Y)
            pred = gen.predict(net_inputs)

            sio.savemat(
                'results/outputs/' + str(j) + '.mat', {
                    'X': X[0],
                    'Y': Y,
                    'pred': pred,
                    'loss': loss,
                    'src_pose': X[-2],
                    'tgt_pose': X[-1]
                })
Ejemplo n.º 5
0
def train(model_name, gpu_id):
    """Fine-tune the fg/bg generator with a Wasserstein-style critic.

    The generator is initialised from
    ``../results/networks/fgbg_vgg_new/184000.h5``. For steps 184001 to
    189000 the discriminator is updated twice per step, then the combined
    GAN once. The discriminator uses the ``networks.wass`` loss with
    RMSprop, its weights are clipped to ``[-0.01, 0.01]`` before each of its
    updates, and labels are -1 for real and +1 for generated images. The GAN
    is saved to ``<project_dir>/results/networks/<model_name>/<step>.h5``
    whenever ``step`` is a multiple of ``params['model_save_interval']``.

    Args:
        model_name: Sub-directory name for the saved networks.
        gpu_id: Index of the GPU used for model construction and prediction
            (``'/gpu:<gpu_id>'``).
    """
    params = param.getGeneralParams()
    gpu = '/gpu:' + str(gpu_id)

    network_dir = params['project_dir'] + '/results/networks/' + model_name

    # makedirs also creates a missing parent, where os.mkdir would raise.
    if not os.path.isdir(network_dir):
        os.makedirs(network_dir)

    train_feed = datageneration.createFeed(params, "train_vids.txt", 50000)
    # Created but not consumed anywhere in the visible training loop.
    test_feed = datageneration.createFeed(params, "test_vids.txt", 5000)

    batch_size = params['batch_size']

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    gan_lr = 5e-5
    disc_lr = 5e-5
    disc_loss = 0.1

    # Iteration count of the pretrained generator; also the step offset.
    vgg_model_num = 184000

    with tf.device(gpu):
        vgg_model = myVGG.vgg_norm()
        networks.make_trainable(vgg_model, False)
        response_weights = sio.loadmat('mean_response.mat')
        generator = networks.network_fgbg(params, vgg_model, response_weights)
        generator.load_weights('../results/networks/fgbg_vgg_new/' +
                               str(vgg_model_num) + '.h5')

        discriminator = networks.discriminator(params)
        discriminator.compile(loss=networks.wass, optimizer=RMSprop(disc_lr))
        # The GAN is not compiled here; networks.gan receives the loss weight
        # and learning rate, so presumably compiles it itself -- TODO confirm.
        gan = networks.gan(generator, discriminator, params, vgg_model,
                           response_weights, disc_loss, gan_lr)

    # Discriminator labels: first half real (-1), second half generated (+1).
    disc_labels = np.ones(2 * batch_size)
    disc_labels[0:batch_size] = -1
    # GAN labels: the generator is trained towards the "real" label.
    gan_labels = -1 * np.ones(batch_size)

    # range() is valid on both Python 2 and Python 3.
    for step in range(vgg_model_num + 1, vgg_model_num + 5001):
        for _ in range(2):
            # Clip every discriminator weight before the update.
            for layer in discriminator.layers:
                clipped = [np.clip(w, -0.01, 0.01) for w in layer.get_weights()]
                layer.set_weights(clipped)

            X, Y = next(train_feed)

            with tf.device(gpu):
                gen = generator.predict(X)

            # Train discriminator on real targets followed by generated
            # images, both paired with the same source/target poses.
            networks.make_trainable(discriminator, True)

            X_tgt_img_disc = np.concatenate((Y, gen))
            X_src_pose_disc = np.concatenate((X[1], X[1]))
            X_tgt_pose_disc = np.concatenate((X[2], X[2]))

            inputs = [X_tgt_img_disc, X_src_pose_disc, X_tgt_pose_disc]
            d_loss = discriminator.train_on_batch(inputs, disc_labels)
            networks.make_trainable(discriminator, False)

        # Train the GAN on a fresh batch.
        X, Y = next(train_feed)
        g_loss = gan.train_on_batch(X, [Y, gan_labels])
        util.printProgress(step, 0, [g_loss[1], d_loss])

        if step % params['model_save_interval'] == 0:
            gan.save(network_dir + '/' + str(step) + '.h5')
Ejemplo n.º 6
0
def train(dataset, gpu_id):
    """Compare three pix2pix variants (L1, VGG, GAN) on 500 test examples.

    Despite its name this function does no training. It loads the
    ``ed_vgg``, ``ed_gan`` and ``ed_l1`` weights, predicts with each model on
    every example and records L1, VGG and SSIM errors. Per-example outputs go
    to ``results/comparison/pix2pix/<j>.mat``; the accumulated metrics, poses
    and classes go to ``results/comparison_pix2pix.mat``.

    Args:
        dataset: Unused in this function; kept so existing callers still work.
        gpu_id: Index of the GPU the networks are placed on
            (``'/gpu:<gpu_id>'``).
    """
    params = param.getGeneralParams()
    gpu = '/gpu:' + str(gpu_id)

    np.random.seed(17)
    feed = datageneration.createFeed(params, 'test_vids.txt', 5000, False, True, True)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    with tf.device(gpu):
        vgg_model = myVGG.vgg_norm()
        networks.make_trainable(vgg_model, False)
        response_weights = sio.loadmat('mean_response.mat')

        # The equivalent fg/bg (network_fgbg) comparison and its foreground
        # mask model were disabled in this script; only pix2pix is compared.
        ed_vgg = networks.network_pix2pix(params, vgg_model, response_weights)
        ed_vgg.load_weights('../results/networks/ed_vgg/135000.h5')

        gen = networks.network_pix2pix(params, vgg_model, response_weights)
        disc = networks.discriminator(params)
        gan = networks.gan(gen, disc, params, vgg_model, response_weights, 0.1, 1e-4)
        # Weights are loaded through the combined model; presumably this also
        # populates `gen`, which was passed into it -- TODO confirm.
        gan.load_weights('../results/networks/ed_gan/2000.h5')

        ed_l1 = networks.network_pix2pix(params, vgg_model, response_weights, loss='l1')
        ed_l1.load_weights('../results/networks/ed_l1/80000.h5')

    n_examples = 500

    # Columns 0-2: L1 error, 3-5: VGG error, 6-8: SSIM error; within each
    # group the order is (l1 model, vgg model, gan model).
    metrics = np.zeros((n_examples, 9))
    poses = np.zeros((n_examples, 28 * 2))
    classes = np.zeros(n_examples)

    # range/print() are valid on both Python 2 and Python 3.
    for j in range(n_examples):
        print(j)
        X, Y = next(feed)

        # Only the first three entries of X are network inputs.
        net_inputs = X[:3]
        pred_l1 = ed_l1.predict(net_inputs)
        pred_vgg = ed_vgg.predict(net_inputs)
        pred_gan = gen.predict(net_inputs)

        preds = [pred_l1, pred_vgg, pred_gan]
        targets = [Y, Y, Y]

        metrics[j, 0:3] = [l1Error(p, t) for p, t in zip(preds, targets)]
        metrics[j, 3:6] = [
            vggError(vgg_model.predict(util.vgg_preprocess(p)),
                     vgg_model.predict(util.vgg_preprocess(t)),
                     response_weights)
            for p, t in zip(preds, targets)
        ]
        metrics[j, 6:] = [ssimError(p, t) for p, t in zip(preds, targets)]

        # The last three entries of X are stored as two 28-value pose
        # vectors and a class id.
        poses[j, 0:28] = X[-3]
        poses[j, 28:] = X[-2]
        classes[j] = int(X[-1])

        sio.savemat('results/comparison/pix2pix/' + str(j) + '.mat',
                    {'X': X[0], 'Y': Y, 'pred_l1': pred_l1,
                     'pred_vgg': pred_vgg, 'pred_gan': pred_gan})
        # NOTE(review): the summary file is rewritten on every iteration, as
        # in the original; presumably so partial results survive an
        # interrupted run.
        sio.savemat('results/comparison_pix2pix.mat',
                    {'metrics': metrics, 'poses': poses, 'classes': classes})
def _posewarp_feed_dict(x, y):
    """Map one (inputs, target) batch onto graph placeholders by tensor name.

    The five input arrays are fed twice: once under the named placeholders
    and once under ``input_3`` .. ``input_7``.
    """
    return {
        "in_img0:0": x[0], "in_pose0:0": x[1], "in_pose1:0": x[2],
        "mask_prior:0": x[3], "trans_in:0": x[4], "in_img1:0": y,
        "input_3:0": x[0], "input_4:0": x[1], "input_5:0": x[2],
        "input_6:0": x[3], "input_7:0": x[4],
    }


def train(model_name, gpu_id):
    """Train the posewarp GAN from scratch with TensorBoard image summaries.

    The discriminator is trained on real targets (label 1) concatenated with
    generated images (label 0). From step 5 onwards the combined GAN is also
    trained on a fresh batch, using a VGG loss plus binary cross-entropy
    weighted ``[1.0, 0.1]``. Every ``params['test_interval']`` steps an image
    summary is written to ``./logs`` and the generator output for one fixed
    training batch is saved under ``../Output/<model_name>/``. The GAN is
    saved to ``<model_save_dir>/<model_name>/<step>.h5`` every
    ``params['model_save_interval']`` steps.

    Args:
        model_name: Sub-directory name for saved networks and output images.
        gpu_id: Value exported as ``CUDA_VISIBLE_DEVICES``.
    """
    # This has to be exported before the session is created; setting it
    # afterwards cannot affect the devices the session has already claimed.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    with tf.Session(config=tf_config) as sess:

        params = param.get_general_params()
        network_dir = params['model_save_dir'] + '/' + model_name

        # makedirs also creates a missing parent, where os.mkdir would raise.
        if not os.path.isdir(network_dir):
            os.makedirs(network_dir)

        train_feed = data_generation.create_feed(params, params['data_dir'], "train")

        gan_lr = 1e-3
        disc_lr = 1e-3
        disc_loss = 0.1

        generator = networks.network_posewarp(params)

        discriminator = networks.discriminator(params)
        discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=disc_lr))

        vgg_model = truncated_vgg.vgg_norm()
        networks.make_trainable(vgg_model, False)
        response_weights = sio.loadmat('../Models/vgg_activation_distribution_train.mat')

        gan = networks.gan(generator, discriminator, params)

        gan.compile(optimizer=Adam(lr=gan_lr),
                    loss=[networks.vgg_loss(vgg_model, response_weights, 12), 'binary_crossentropy'],
                    loss_weights=[1.0, disc_loss])

        n_iters = params['n_training_iter']
        batch_size = params['batch_size']

        summary_writer = tf.summary.FileWriter("./logs", graph=sess.graph)

        # Fixed batch used to visualise generator progress over time.
        tr_x, tr_y = next(train_feed)

        # Prepare output directories if they don't exist.
        output_dir = '../Output/' + model_name + '/'

        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # NOTE(review): scipy.misc.imsave is unavailable in recent SciPy
        # releases -- confirm the pinned SciPy version.
        scipy.misc.imsave(output_dir + 'tr_orig_image.png', tr_x[0][0, :, :, :])
        scipy.misc.imsave(output_dir + 'tr_targ_image.png', tr_y[0, :, :, :])

        print("Batch size: " + str(batch_size))

        # Discriminator labels: first half real (1), second half generated (0).
        disc_labels = np.zeros([2 * batch_size])
        disc_labels[0:batch_size] = 1
        # GAN labels: the generator is trained to be classified as real.
        gan_labels = np.ones([batch_size])

        # Built lazily on first use and then reused. Creating a new summary
        # op at every test interval would keep adding nodes to the graph.
        graph = tf.get_default_graph()
        gen_tensor = None
        image_summary_op = None

        try:
            for step in range(n_iters):

                x, y = next(train_feed)

                gen = generator.predict(x)

                # Train discriminator on real targets followed by generated
                # images, both paired with the same source/target poses.
                x_tgt_img_disc = np.concatenate((y, gen))
                x_src_pose_disc = np.concatenate((x[1], x[1]))
                x_tgt_pose_disc = np.concatenate((x[2], x[2]))

                inputs = [x_tgt_img_disc, x_src_pose_disc, x_tgt_pose_disc]
                d_loss = discriminator.train_on_batch(inputs, disc_labels)

                # Train the discriminator a couple of iterations before starting the gan
                if step < 5:
                    util.printProgress(step, 0, [0, d_loss])
                    continue

                # Train the GAN on a fresh batch.
                x, y = next(train_feed)
                g_loss = gan.train_on_batch(x, [y, gan_labels])
                util.printProgress(step, 0, [g_loss[1], d_loss])

                if step % params['test_interval'] == 0:

                    print(gen[0])

                    if image_summary_op is None:
                        gen_tensor = graph.get_tensor_by_name("model_1/add_2_1/add:0")
                        inp = graph.get_tensor_by_name("in_img0:0")
                        out = graph.get_tensor_by_name("in_img1:0")
                        image_summary_op = tf.summary.image(
                            'images',
                            [inp[0, :, :, :], out[0, :, :, :], gen_tensor[0, :, :, :]],
                            max_outputs=100)

                    image_summary = sess.run(image_summary_op,
                                             feed_dict=_posewarp_feed_dict(x, y))
                    # The step is recorded so the reused summary tag forms a
                    # time series instead of overwriting a single point.
                    summary_writer.add_summary(image_summary, global_step=step)

                    train_image = sess.run(gen_tensor,
                                           feed_dict=_posewarp_feed_dict(tr_x, tr_y))

                    scipy.misc.imsave(output_dir + 'tr' + str(step) + ".png", train_image[0, :, :, :])

                if step % params['model_save_interval'] == 0 and step > 0:
                    gan.save(network_dir + '/' + str(step) + '.h5')
        finally:
            # Flush pending summaries even if training is interrupted.
            summary_writer.close()
Ejemplo n.º 8
0
def test(model_name, gpu_id):
    """Dump generator outputs and masks for one test batch as JPEGs in ``pics/``.

    Loads the fine-tuned posewarp generator, runs it on a single
    un-augmented test batch and writes, for each batch element ``i``: the
    source image, the generated image, the target image, and the first 11
    channels of the ``mask_delta`` and ``mask_src`` layer outputs.

    The original body was a copy of the training loop with an unconditional
    ``break`` after the first batch. The unreachable training code, and the
    discriminator/GAN/VGG setup that only served it, have been removed.

    Args:
        model_name: Unused in this function; kept so existing callers still
            work.
        gpu_id: Value exported as ``CUDA_VISIBLE_DEVICES``.
    """
    params = param.get_general_params()

    test_feed = data_generation.create_feed(params,
                                            params['data_dir'],
                                            'test',
                                            do_augment=False)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # NOTE(review): this config is built but never passed to a session in
    # this function, so these options have no effect here -- confirm whether
    # a set_session(tf.Session(config=config)) call is missing.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    generator = networks.network_posewarp(params)
    generator.load_weights(
        '/versa/kangliwei/motion_transfer/posewarp-cvpr2018/models/0301_fullfinetune/9000.h5'
    )

    # Side models exposing intermediate mask layers of the generator.
    # Arguments are positional rather than the legacy input=/output= keywords.
    mask_delta_model = Model(generator.input,
                             generator.get_layer('mask_delta').output)
    src_mask_model = Model(generator.input,
                           generator.get_layer('mask_src').output)

    x, y = next(test_feed)

    gen = generator.predict(x)

    src_mask_delta = mask_delta_model.predict(x)
    print('delta_max', src_mask_delta.max())
    src_mask_delta = src_mask_delta * 255
    src_mask = src_mask_model.predict(x)
    print('mask_max', src_mask.max())
    src_mask = src_mask * 255

    # Rescale target and generated images by (v / 2 + 0.5) * 255.
    y = (y / 2 + 0.5) * 255.0
    gen = (gen / 2 + 0.5) * 255.0

    # cv2.imwrite does not create a missing directory, so ensure it exists.
    if not os.path.isdir('pics'):
        os.makedirs('pics')

    for i in range(gen.shape[0]):  # iterate in batch
        # NOTE(review): the source image is scaled by 255 only, unlike y and
        # gen above -- confirm the value range of x[0].
        cv2.imwrite('pics/src' + str(i) + '.jpg', x[0][i] * 255)
        cv2.imwrite('pics/gen' + str(i) + '.jpg', gen[i])
        cv2.imwrite('pics/y' + str(i) + '.jpg', y[i])
        for j in range(11):
            cv2.imwrite('pics/seg_delta_' + str(i) + '_' + str(j) + '.jpg',
                        src_mask_delta[i][:, :, j])
        for j in range(11):
            cv2.imwrite('pics/seg_' + str(i) + '_' + str(j) + '.jpg',
                        src_mask[i][:, :, j])