Example #1
def _handler(ir_path, vis_path, model_path, model_pre_path, ssim_weight, index, output_path=None):
	ir_img = get_train_images(ir_path, flag=False)
	vis_img = get_train_images(vis_path, flag=False)
	# ir_img = get_train_images_rgb(ir_path, flag=False)
	# vis_img = get_train_images_rgb(vis_path, flag=False)
	dimension = ir_img.shape

	ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
	vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])

	ir_img = np.transpose(ir_img, (0, 2, 1, 3))
	vis_img = np.transpose(vis_img, (0, 2, 1, 3))

	print('img shape final:', ir_img.shape)

	with tf.Graph().as_default(), tf.Session() as sess:
		infrared_field = tf.placeholder(
			tf.float32, shape=ir_img.shape, name='content')
		visible_field = tf.placeholder(
			tf.float32, shape=ir_img.shape, name='style')

		dfn = DenseFuseNet(model_pre_path)

		output_image = dfn.transform_addition(infrared_field, visible_field)
		# restore the trained model and run the style transfer
		saver = tf.train.Saver()
		saver.restore(sess, model_path)

		output = sess.run(output_image, feed_dict={infrared_field: ir_img, visible_field: vis_img})

		save_images(ir_path, output, output_path,
		            prefix='fused' + str(index), suffix='_densefuse_addition_'+str(ssim_weight))
Example #2
def _handler_l1(ir_path,
                vis_path,
                model_path,
                model_pre_path,
                ssim_weight,
                index,
                output_path=None):
    ir_img = get_train_images(ir_path, flag=False)
    vis_img = get_train_images(vis_path, flag=False)
    dimension = ir_img.shape

    ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])

    ir_img = np.transpose(ir_img, (0, 2, 1, 3))
    vis_img = np.transpose(vis_img, (0, 2, 1, 3))

    print('img shape final:', ir_img.shape)

    with tf.Graph().as_default(), tf.Session() as sess:

        # build the dataflow graph
        infrared_field = tf.placeholder(tf.float32,
                                        shape=ir_img.shape,
                                        name='content')
        visible_field = tf.placeholder(tf.float32,
                                       shape=ir_img.shape,
                                       name='style')

        dfn = DenseFuseNet(model_pre_path)

        enc_ir = dfn.transform_encoder(infrared_field)
        enc_vis = dfn.transform_encoder(visible_field)

        target = tf.placeholder(tf.float32, shape=enc_ir.shape, name='target')

        output_image = dfn.transform_decoder(target)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        enc_ir_temp, enc_vis_temp = sess.run([enc_ir, enc_vis],
                                             feed_dict={
                                                 infrared_field: ir_img,
                                                 visible_field: vis_img
                                             })
        feature = L1_norm(enc_ir_temp, enc_vis_temp)

        output = sess.run(output_image, feed_dict={target: feature})
        save_images(ir_path,
                    output,
                    output_path,
                    prefix='fused' + str(index),
                    suffix='_densefuse_l1norm_' + str(ssim_weight))
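
The L1_norm fusion strategy is defined elsewhere in the repository. A minimal sketch of such a strategy, assuming it weights each source's encoded features by per-pixel l1 activity (an assumption about the helper, not its actual code):

import numpy as np

def l1_norm_fusion(enc_a, enc_b, eps=1e-8):
    # per-pixel activity: l1 norm over the channel axis (NHWC layout)
    act_a = np.sum(np.abs(enc_a), axis=-1, keepdims=True)
    act_b = np.sum(np.abs(enc_b), axis=-1, keepdims=True)
    w_a = act_a / (act_a + act_b + eps)
    # soft, activity-weighted average of the two encodings
    return w_a * enc_a + (1.0 - w_a) * enc_b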
Example #3
def _get_attention(ir_path, vis_path, model_path_a, model_pre_path_a):
	ir_img = get_train_images(ir_path, flag=False)
	vis_img = get_train_images(vis_path, flag=False)
	dimension = ir_img.shape
	ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
	vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])
	ir_img = np.transpose(ir_img, (0, 2, 1, 3))
	vis_img = np.transpose(vis_img, (0, 2, 1, 3))
	g1 = tf.Graph()  # build this handler's ops in a dedicated graph

	with g1.as_default(), tf.Session(graph=g1) as sess:
		# distinct placeholder names keep the two attention inputs separable
		edge_ir = tf.placeholder(tf.float32, shape=ir_img.shape, name='attention_ir')
		edge_vis = tf.placeholder(tf.float32, shape=ir_img.shape, name='attention_vis')

		# -----------------------------------------------
		# the numpy inputs can be used directly; running a placeholder on
		# its own feed is just an identity op
		p_vis = vis_img[0]
		p_ir = ir_img[0]

		p_vis = np.squeeze(p_vis)  # drop singleton dimensions
		p_ir = np.squeeze(p_ir)

		guideFilter_img_vis = Grad(p_vis)
		guideFilter_img_ir = Grad(p_ir)

		guideFilter_img_vis[guideFilter_img_vis < 0] = 0
		guideFilter_img_ir[guideFilter_img_ir < 0] = 0
		guideFilter_img_vis = np.expand_dims(guideFilter_img_vis, axis=-1)
		guideFilter_img_ir = np.expand_dims(guideFilter_img_ir, axis=-1)
		guideFilter_img_vis = np.expand_dims(guideFilter_img_vis, axis=0)
		guideFilter_img_ir = np.expand_dims(guideFilter_img_ir, axis=0)

		a = attention.Attention(model_pre_path_a)
		feature_a = a.get_attention(edge_ir)
		feature_b = a.get_attention(edge_vis)

		saver = tf.train.Saver()
		saver.restore(sess, model_path_a)

		edge_ir_temp = sess.run(feature_a, feed_dict={edge_ir: guideFilter_img_ir})
		edge_vis_temp = sess.run(feature_b, feed_dict={edge_vis: guideFilter_img_vis})

		return edge_ir_temp, edge_vis_temp
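
Grad is also external to this snippet. A plausible stand-in that produces an edge map with a Laplacian filter (purely an assumption; the repository's Grad may differ):

import cv2
import numpy as np

def grad_sketch(img):
    # Laplacian response as a rough gradient/edge map; the caller above
    # clips negative responses to zero afterwards
    return cv2.Laplacian(img.astype(np.float32), cv2.CV_32F)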
Example #4
def _handler(content_path,
             style_path,
             encoder_path,
             model_path,
             model_pre_path,
             output_path=None):

    with tf.Graph().as_default(), tf.Session() as sess:
        index = 2
        content_path = content_path + str(index) + '.jpg'
        style_path = style_path + str(index) + '.jpg'

        content_img = get_train_images(content_path)
        style_img = get_train_images(style_path)

        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=content_img.shape,
                                 name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path, model_pre_path)

        enc_c, enc_s = stn.encoder_process(content, style)

        target = tf.placeholder(tf.float32, shape=enc_c.shape, name='target')

        # output_image = stn.transform(content, style)
        output_image = stn.decoder_process(target)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        # get the output
        enc_c, enc_s = sess.run([enc_c, enc_s],
                                feed_dict={
                                    content: content_img,
                                    style: style_img
                                })
        feature = L1_Max(enc_c, enc_s)
        # feature = enc_s
        output = sess.run(output_image, feed_dict={target: feature})

    if output_path is not None:
        save_images(content_path,
                    output,
                    output_path,
                    prefix='fused' + str(index) + '_',
                    suffix='deep')

    return output
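
L1_Max is likewise not shown here. One plausible reading, sketched under the assumption that it keeps, per spatial position, the feature vector with the larger l1 activity:

import numpy as np

def l1_max_fusion(enc_a, enc_b):
    # per-pixel l1 activity over the channel axis (NHWC layout)
    act_a = np.sum(np.abs(enc_a), axis=-1, keepdims=True)
    act_b = np.sum(np.abs(enc_b), axis=-1, keepdims=True)
    # hard selection: take the more active source at each position
    return np.where(act_a >= act_b, enc_a, enc_b)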
Example #5
def _handler(content_name,
             style_name,
             model_path,
             model_pre_path,
             index,
             output_path=None):
    content_path = content_name
    style_path = style_name

    content_img = get_train_images(content_path, flag=False)
    style_img = get_train_images(style_path, flag=False)
    dimension = content_img.shape

    content_img = content_img.reshape(
        [1, dimension[0], dimension[1], dimension[2]])
    style_img = style_img.reshape(
        [1, dimension[0], dimension[1], dimension[2]])

    content_img = np.transpose(content_img, (0, 2, 1, 3))
    style_img = np.transpose(style_img, (0, 2, 1, 3))
    print('content_img shape final:', content_img.shape)

    with tf.Graph().as_default(), tf.Session() as sess:

        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=content_img.shape,
                                 name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        dfn = DeepFuseNet(model_pre_path)

        output_image = dfn.transform_addition(content, style)
        # output_image = dfn.transform_recons(style)
        # output_image = dfn.transform_recons(content)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={
                              content: content_img,
                              style: style_img
                          })
        save_images(content_path,
                    output,
                    output_path,
                    prefix='fused' + str(index),
                    suffix='_deepfuse_bs2_epoch2')

    return output
Example #6
def _handler_video(ir_path, vis_path, model_path, model_pre_path, ssim_weight, output_path=None):
	infrared = ir_path[0]
	img = get_train_images(infrared, flag=False)
	img = img.reshape([1, img.shape[0], img.shape[1], img.shape[2]])
	img = np.transpose(img, (0, 2, 1, 3))
	print('img shape final:', img.shape)
	num_imgs = len(ir_path)

	with tf.Graph().as_default(), tf.Session() as sess:
		# build the dataflow graph
		infrared_field = tf.placeholder(
			tf.float32, shape=img.shape, name='content')
		visible_field = tf.placeholder(
			tf.float32, shape=img.shape, name='style')

		dfn = DenseFuseNet(model_pre_path)

		output_image = dfn.transform_addition(infrared_field, visible_field)

		# restore the trained model and run the style transfer
		saver = tf.train.Saver()
		saver.restore(sess, model_path)

		##################GET IMAGES###################################################################################
		start_time = datetime.now()
		for i in range(num_imgs):
			print('image number:', i)
			infrared = ir_path[i]
			visible = vis_path[i]

			ir_img = get_train_images(infrared, flag=False)
			vis_img = get_train_images(visible, flag=False)
			dimension = ir_img.shape

			ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
			vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])

			ir_img = np.transpose(ir_img, (0, 2, 1, 3))
			vis_img = np.transpose(vis_img, (0, 2, 1, 3))

			################FEED########################################
			output = sess.run(output_image, feed_dict={infrared_field: ir_img, visible_field: vis_img})
			save_images(infrared, output, output_path,
			            prefix='fused' + str(i), suffix='_addition_' + str(ssim_weight))
			######################################################################################################
		elapsed_time = datetime.now() - start_time
		print('Dense block video==> elapsed time: %s' % (elapsed_time))
Example #7
def _handler(ir_path,
             vis_path,
             model_path,
             model_pre_path,
             ssim_weight,
             index,
             output_path=None):
    ir_img = get_train_images(ir_path, flag=False)
    vis_img = get_train_images(vis_path, flag=False)
    # ir_img = get_train_images_rgb(ir_path, flag=False)
    # vis_img = get_train_images_rgb(vis_path, flag=False)
    dimension = ir_img.shape

    ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])

    ir_img = np.transpose(ir_img, (0, 2, 1, 3))
    vis_img = np.transpose(vis_img, (0, 2, 1, 3))

    print('img shape final:', ir_img.shape)

    with tf.Graph().as_default(), tf.Session() as sess:
        infrared_field = tf.placeholder(tf.float32,
                                        shape=ir_img.shape,
                                        name='content')
        visible_field = tf.placeholder(tf.float32,
                                       shape=ir_img.shape,
                                       name='style')

        dfn = DenseFuseNet(model_pre_path)

        output_image = dfn.transform_addition(infrared_field, visible_field)
        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={
                              infrared_field: ir_img,
                              visible_field: vis_img
                          })

        save_images(ir_path, output, output_path, prefix='3', suffix='')
        img333 = Image.open('/home/bingyang/wby/1/3.png')
        imgout = img333.transpose(Image.FLIP_LEFT_RIGHT)
        imgout = imgout.transpose(Image.ROTATE_90)
        imgout.save('/home/bingyang/wby/1/3.png')
Example #8
def _handler_l1(content_name, style_name, model_path, model_pre_path, ssim_weight, index, output_path=None):
    infrared_path = content_name
    visible_path = style_name

    content_img = get_train_images(infrared_path, flag=False)
    style_img   = get_train_images(visible_path, flag=False)
    dimension = content_img.shape

    content_img = content_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    style_img   = style_img.reshape([1, dimension[0], dimension[1], dimension[2]])

    content_img = np.transpose(content_img, (0, 2, 1, 3))
    style_img = np.transpose(style_img, (0, 2, 1, 3))
    print('content_img shape final:', content_img.shape)

    with tf.Graph().as_default(), tf.Session() as sess:

        # build the dataflow graph
        content = tf.placeholder(
            tf.float32, shape=content_img.shape, name='content')
        style = tf.placeholder(
            tf.float32, shape=style_img.shape, name='style')

        dfn = DenseFuseNet(model_pre_path)

        enc_c = dfn.transform_encoder(content)
        enc_s = dfn.transform_encoder(style)

        target = tf.placeholder(
            tf.float32, shape=enc_c.shape, name='target')

        output_image = dfn.transform_decoder(target)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        enc_c_temp, enc_s_temp = sess.run([enc_c, enc_s], feed_dict={content: content_img, style: style_img})
        feature = L1_norm(enc_c_temp, enc_s_temp)

        output = sess.run(output_image, feed_dict={target: feature})
        save_images(infrared_path, output, output_path,
                    prefix='fused' + str(index), suffix='_densefuse_l1norm_'+str(ssim_weight))

    return output
Example #9
def test_utils_imoperations():
    from utils import imread, imresize_square, get_train_images, imsave
    path_read = '/tmp/panda.jpg'
    path_save = '/tmp/panda_resized.jpg'
    # image = imread(path_read, mode='RGB')
    # image = imresize_square(image, long_side=256, interp = 'nearest')
    # imsave(path_save, image)
    images = get_train_images([path_read])
    imsave(path_save, images[0])
Example #10
def run(gpuids, q):
    # scan all files under img_path
    names = get_train_images()

    # init scheduler
    x = Scheduler(gpuids, q)

    # start processing and wait for complete
    return x.start(names)
Example #11
def train_recons(inputPath, validationPath, save_path, model_pre_path, EPOCHES_set, BATCH_SIZE, debug=False, logging_period=1):
    from datetime import datetime
    start_time = datetime.now()
    path = './models/performanceData/'
    fileName = 'TrainPerformanceData_'+str(start_time)+'.txt'
    fileName = fileName.replace(" ", "_")
    fileName = fileName.replace(":", "_")
    file = open(path+fileName, 'w')
    file.close()
    folders = list_folders(inputPath)
    valFolders = list_folders(validationPath)
    EPOCHS = EPOCHES_set
    print("EPOCHES   : ", EPOCHS)
    print("BATCH_SIZE: ", BATCH_SIZE)
    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)
    GROUNDTRUTH_SHAPE_OR = (1, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='original')
        groundtruth = tf.placeholder(tf.float32, shape=GROUNDTRUTH_SHAPE_OR, name='groundtruth')
        source = original

        print('source  :', source.shape)
        print('original:', original.shape)
        print('groundtruth:', groundtruth.shape)
        # create the deepfuse net (encoder and decoder)
        dfn = DenseFuseNet(model_pre_path)
        generated_img = dfn.transform_recons_train(source)
        print('generate:', generated_img.shape)
        pixel_loss = tf.reduce_sum(tf.square(groundtruth - generated_img))
        pixel_loss = tf.math.sqrt(pixel_loss / (BATCH_SIZE * HEIGHT * WIDTH))
        loss = pixel_loss
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        numTrainSite = len(folders)
        numValSite = len(valFolders)

        for epoch in range(EPOCHS):
            save_path_epoc = './models_intermediate/'+str(epoch)+'.ckpt'
            start_time_epoc = datetime.now()
            for site in range(numTrainSite):
                start_time_site = datetime.now()
                file = open(path + fileName, 'a')
                groundtruth_imgs_path = list_images(inputPath+folders[site] + '/gt/')
                training_imgs_path = list_images(inputPath+folders[site])
                np.random.shuffle(training_imgs_path)
                gt = get_train_images(groundtruth_imgs_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                gtImgTrain = np.zeros(GROUNDTRUTH_SHAPE_OR)
                gtImgTrain[0] = gt
                n_batches = int(len(training_imgs_path) // BATCH_SIZE)
                for batch in range(n_batches):
                    original_path = training_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                    original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                    original_batch = original_batch.reshape([BATCH_SIZE, 256, 256, 1])
                    sess.run(train_op, feed_dict={original: original_batch, groundtruth: gtImgTrain})
                if debug:
                    for batch in range(n_batches):
                        original_path = training_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                        original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                        original_batch = original_batch.reshape([BATCH_SIZE, 256, 256, 1])

                        # print('original_batch shape final:', original_batch.shape)

                        # run the training step
                        _p_loss = sess.run(pixel_loss, feed_dict={original: original_batch, groundtruth: gtImgTrain})
                        # log mode (train/validation), epoch, site, batch and loss to the text file
                        file.write('Train[Epoch#: %d, Site#: %d, Batch#: %d, _p_loss: %.4f]\n' % (epoch, site, batch, _p_loss))
                        print('Train[Epoch#: %d, Site#: %d, Batch#: %d, _p_loss: %.4f]' % (epoch, site, batch, _p_loss))
                print('Time taken per site: %s' %(datetime.now() - start_time_site))
                file.close()
            for site in range(numValSite):
                file = open(path + fileName, 'a')
                start_time_validation = datetime.now()
                groundtruth_val_imgs_path = list_images(validationPath+valFolders[site] + '/gt/')
                validation_imgs_path = list_images(validationPath+valFolders[site])
                np.random.shuffle(validation_imgs_path)
                gt = get_train_images(groundtruth_val_imgs_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                gtImgVal = np.zeros(GROUNDTRUTH_SHAPE_OR)
                gtImgVal[0] = gt
                val_batches = int(len(validation_imgs_path) // BATCH_SIZE)
                val_pixel_acc = 0
                for batch in range(val_batches):
                    val_original_path = validation_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                    val_original_batch = get_train_images(val_original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                    val_original_batch = val_original_batch.reshape([BATCH_SIZE, 256, 256, 1])

                    val_pixel = sess.run(pixel_loss, feed_dict={original: val_original_batch, groundtruth: gtImgVal})
                    file.write('Validation[Epoch#: %d, Site#: %d, Batch#: %d, _p_loss: %.4f]\n' % (epoch, site, batch, val_pixel))
                    val_pixel_acc = val_pixel_acc + val_pixel
                print('Time taken per validation site: %s' % (datetime.now() - start_time_validation))
                val_loss = val_pixel_acc/val_batches
                file.write('ValidationAcc[Epoch#: %d, Site#: %d, Batch#: %d, val_loss: %.4f]\n' % (epoch, site, batch, val_loss))
                print('ValidationAcc[Epoch#: %d, Site#: %d, Batch#: %d, val_loss: %.4f]' % (epoch, site, batch, val_loss))
                file.close()
            print('------------------------------------------------------------------------------')
            print('Time taken per epoch: %s' % (datetime.now() - start_time_epoc))
            saver.save(sess, save_path_epoc)
        saver.save(sess, save_path)
        print('Done training!')
        print('Total Time taken (training): %s' % (datetime.now() - start_time))
        file.close()
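
The pixel loss in this example is a root-mean-square error, sqrt(sum((gt - gen)^2) / (BATCH_SIZE * HEIGHT * WIDTH)). A small numpy equivalent for sanity-checking logged values (shapes assumed to match the placeholders above):

import numpy as np

def rmse_pixel_loss(groundtruth, generated, batch_size, height, width):
    # mirrors sqrt(reduce_sum(square(gt - gen)) / (B * H * W)) from the graph
    squared_sum = np.sum(np.square(groundtruth - generated))
    return np.sqrt(squared_sum / (batch_size * height * width))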
Example #12
def train(ssim_weight,
          original_imgs_path_name,
          source_a_imgs_path,
          source_b_imgs_path_name,
          encoder_path,
          save_path,
          model_pre_path,
          debug=False,
          logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # num_imgs = len(source_a_imgs_path)
    num_imgs = 10000
    source_a_imgs_path = source_a_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE

    print('Train images number %d.\n' % num_imgs)
    print('Train batches per epoch %s.\n' % str(num_imgs / BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        source_a_imgs_path = source_a_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE_OR,
                                  name='original')
        source_a = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE,
                                  name='source_a')
        source_b = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE,
                                  name='source_b')

        print('source:', source_a.shape)

        # create the style transfer net
        stn = StyleTransferNet(encoder_path, model_pre_path)

        # pass content and style to the stn, getting the generated_img, fused image
        generated_img = stn.transform(source_a, source_b)

        # # get the target feature maps which is the output of AdaIN
        # target_features = stn.target_features

        pixel_loss = tf.reduce_sum(
            tf.reduce_mean(tf.square(original - generated_img), axis=[1, 2]))
        pixel_loss = pixel_loss / (HEIGHT * WIDTH)

        # compute the SSIM loss
        ssim_loss = 1 - SSIM.tf_ssim(original, generated_img)

        # compute the total loss
        loss = pixel_loss + ssim_weight * ssim_loss

        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(source_a_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                '\nElapsed time for preprocessing before actually train the model: %s'
                % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        Loss_all = [i for i in range(EPOCHS * n_batches)]
        for epoch in range(EPOCHS):

            np.random.shuffle(source_a_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of content and style images

                source_a_path = source_a_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                source_a_str = source_a_path[0]
                name_f = source_a_str.find('\\')
                source_image_name = source_a_str[name_f + 1:]
                source_image_name_comm = source_image_name[2:]

                source_b_path = [source_b_imgs_path_name + source_image_name]
                original_path = [
                    original_imgs_path_name + source_image_name_comm
                ]

                original_batch = get_train_images(original_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH,
                                                  flag=False)
                source_a_batch = get_train_images(source_a_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH)
                source_b_batch = get_train_images(source_b_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH)

                original_batch = original_batch.reshape([1, 256, 256, 1])

                # run the training step
                sess.run(train_op,
                         feed_dict={
                             original: original_batch,
                             source_a: source_a_batch,
                             source_b: source_b_batch
                         })
                step += 1
                # if step % 1000 == 0:
                #     saver.save(sess, save_path, global_step=step)
                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _pixel_loss, _ssim_loss, _loss = sess.run(
                            [pixel_loss, ssim_loss, loss],
                            feed_dict={
                                original: original_batch,
                                source_a: source_a_batch,
                                source_b: source_b_batch
                            })
                        Loss_all[count_loss] = _loss
                        count_loss += 1
                        print(
                            'step: %d,  total loss: %.3f,  elapsed time: %s' %
                            (step, _loss, elapsed_time))
                        print('pixel loss: %.3f' % (_pixel_loss))
                        print('ssim loss : %.3f\n' % (_ssim_loss))
                        # print('pca or shape  : ', _pca_or.shape)
                        # print('pca gen shape : ', _pca_gen.shape)

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        iter_index = [i for i in range(count_loss)]
        plt.plot(iter_index, Loss_all[:count_loss])
        plt.show()

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
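
The total loss here is pixel_loss + ssim_weight * (1 - SSIM). A standalone restatement for a pair of single-channel images, using scikit-image's SSIM in place of the repository's SSIM module (that substitution and the data range are assumptions):

import numpy as np
from skimage.metrics import structural_similarity

def combined_fusion_loss(original, generated, ssim_weight, data_range=1.0):
    # pixel term: mean squared error, matching the graph above up to scaling
    pixel = np.mean(np.square(original - generated))
    # structural term: 1 - SSIM, zero when the images are identical
    ssim = structural_similarity(original, generated, data_range=data_range)
    return pixel + ssim_weight * (1.0 - ssim)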
Example #13
def train(style_weight, content_imgs_path, style_imgs_path, encoder_path, 
          model_save_path, debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the number of content and style images is a multiple of BATCH_SIZE
    num_imgs = min(len(content_imgs_path), len(style_imgs_path))
    content_imgs_path = content_imgs_path[:num_imgs]
    style_imgs_path   = style_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        content_imgs_path = content_imgs_path[:-mod]
        style_imgs_path   = style_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
        style   = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

        # create the style transfer net
        stn = StyleTransferNet(encoder_path)

        # pass content and style to the stn, getting the generated_img
        generated_img = stn.transform(content, style)

        # get the target feature maps which is the output of AdaIN
        target_features = stn.target_features

        # pass the generated_img to the encoder, and use the output compute loss
        generated_img = tf.reverse(generated_img, axis=[-1])  # switch RGB to BGR
        generated_img = stn.encoder.preprocess(generated_img) # preprocess image
        enc_gen, enc_gen_layers = stn.encoder.encode(generated_img)

        # compute the content loss
        # content_loss = fft_loss(enc_gen, target_features)
        content_loss = tf.reduce_sum(tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))

        # compute the style loss
        style_layer_loss = []
        for layer in STYLE_LAYERS:
            
            enc_style_feat = stn.encoded_style_layers[layer]
            enc_gen_feat   = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat,   [1, 2])
            # fft_pred = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_gen_feat, (0, 3, 1, 2))))
            # fft_true = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_style_feat, (0, 3, 1, 2))))
            # meanS, varS = tf.nn.moments(fft_pred, [2, 3])
            # meanG, varG = tf.nn.moments(fft_true,   [2, 3])

            sigmaS = tf.sqrt(varS + EPSILON)
            sigmaG = tf.sqrt(varG + EPSILON)

            l2_mean  = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # Training step
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.inverse_time_decay(LEARNING_RATE, global_step, DECAY_STEPS, LR_DECAY_RATE)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

        sess.run(tf.global_variables_initializer())

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        ###### Start Training ######
        step = 0
        n_batches = int(len(content_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            start_time = datetime.now()
            print('\nElapsed time for preprocessing before actually train the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')

        try:
            for epoch in range(EPOCHS):

                np.random.shuffle(content_imgs_path)
                np.random.shuffle(style_imgs_path)

                for batch in range(n_batches):
                    # retrieve a batch of content and style images
                    content_batch_path = content_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]
                    style_batch_path   = style_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]

                    content_batch = get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
                    style_batch   = get_train_images(style_batch_path,   crop_height=HEIGHT, crop_width=WIDTH)

                    # run the training step
                    sess.run(train_op, feed_dict={content: content_batch, style: style_batch})

                    step += 1

                    if step % 1000 == 0:
                        saver.save(sess, model_save_path, global_step=step, write_meta_graph=False)

                    if debug:
                        is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                        if is_last_step or step == 1 or step % logging_period == 0:
                            elapsed_time = datetime.now() - start_time
                            _content_loss, _style_loss, _loss = sess.run([content_loss, style_loss, loss], 
                                feed_dict={content: content_batch, style: style_batch})

                            print('step: %d,  total loss: %.3f,  elapsed time: %s' % (step, _loss, elapsed_time))
                            print('content loss: %.3f' % (_content_loss))
                            print('style loss  : %.3f,  weighted style loss: %.3f\n' % (_style_loss, style_weight * _style_loss))
        except Exception as ex:
            saver.save(sess, model_save_path, global_step=step)
            print('\nSomething went wrong! Current model is saved to <%s>' % model_save_path)
            print('Error message: %s' % str(ex))

        ###### Done Training & Save the model ######
        saver.save(sess, model_save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % model_save_path)
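
The per-layer style loss above matches channel-wise means and standard deviations between style and generated features. A compact numpy restatement (NHWC layout assumed; eps stands in for the snippet's EPSILON):

import numpy as np

def adain_style_layer_loss(feat_style, feat_gen, eps=1e-5):
    # channel-wise statistics over the spatial axes (H, W)
    mean_s, var_s = feat_style.mean(axis=(1, 2)), feat_style.var(axis=(1, 2))
    mean_g, var_g = feat_gen.mean(axis=(1, 2)), feat_gen.var(axis=(1, 2))
    sigma_s = np.sqrt(var_s + eps)
    sigma_g = np.sqrt(var_g + eps)
    # l2 distance between means plus l2 distance between standard deviations
    return np.sum((mean_g - mean_s) ** 2) + np.sum((sigma_g - sigma_s) ** 2)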
Example #14
def _handler_l1(ir_path,
                vis_path,
                model_path,
                model_pre_path,
                ssim_weight,
                index,
                output_path=None):
    ir_img = get_train_images(ir_path, True, flag=False)
    vis_img, Cr, Cb = get_train_images(vis_path, False, flag=False)
    dimension = ir_img.shape

    ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])

    #ir_img = np.transpose(ir_img, (0, 2, 1, 3))
    #vis_img = np.transpose(vis_img, (0, 2, 1, 3))

    print('img shape final:', ir_img.shape)

    with tf.Graph().as_default(), tf.Session() as sess:

        # build the dataflow graph
        infrared_field = tf.placeholder(tf.float32,
                                        shape=ir_img.shape,
                                        name='content')
        visible_field = tf.placeholder(tf.float32,
                                       shape=ir_img.shape,
                                       name='style')

        dfn = DenseFuseNet(model_pre_path)

        enc_ir1, enc_ir2, enc_ir3 = dfn.transform_encoder(infrared_field)
        enc_vis1, enc_vis2, enc_vis3 = dfn.transform_encoder(visible_field)

        temp_enc = tf.concat([enc_ir1, enc_ir2], 3)
        temp_out = tf.concat([temp_enc, enc_ir3], 3)

        target = tf.placeholder(tf.float32,
                                shape=temp_out.shape,
                                name='target')

        output_image = dfn.transform_decoder(target)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        enc_ir1, enc_ir2, enc_ir3, enc_vis1, enc_vis2, enc_vis3 = sess.run(
            [enc_ir1, enc_ir2, enc_ir3, enc_vis1, enc_vis2, enc_vis3],
            feed_dict={
                infrared_field: ir_img,
                visible_field: vis_img
            })
        target_feature1_l1 = L1_norm(enc_ir1, enc_vis1)
        target_feature2_l1 = L1_norm(enc_ir2, enc_vis2)
        target_feature3_l1 = L1_norm(enc_ir3, enc_vis3)

        temp_l1 = tf.concat([target_feature1_l1, target_feature2_l1], 3)
        out_l1 = tf.concat([temp_l1, target_feature3_l1], 3)

        out_l1 = out_l1.eval()

        output = sess.run(output_image, feed_dict={target: out_l1})
        output = output.squeeze()
        result = np.dstack([output, Cr, Cb])
        result = cv2.cvtColor(result, cv2.COLOR_YCrCb2BGR)
        cv2.imwrite((output_path + str(index) +
                     '_multiScale_densefuse_l1_MRSPECT_ssim' +
                     str(ssim_weight) + '.jpg'), result)
Example #15
    proc.start()

    out_queue = run(gpuids, q)
    out_dict = {}
    while out_queue.qsize() > 0:
        item = out_queue.get()
        out_dict[item['image_name']] = item['embedding']

    with open("data/train_embeddings.p", "wb") as file:
        pickle.dump(out_dict, file)

    q.put(None)
    proc.join()


train_images = get_train_images()


def calculate_distance_list(image_i):
    embedding_i = embeddings[image_i]
    distance_list = np.empty(shape=(num_train_samples,), dtype=np.float32)
    for j, image_j in enumerate(train_images):
        embedding_j = embeddings[image_j]
        dist = np.square(np.linalg.norm(embedding_i - embedding_j))
        distance_list[j] = dist
    return distance_list


if __name__ == '__main__':
    print('creating train embeddings')
    create_train_embeddings()
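
calculate_distance_list computes squared Euclidean distances one pair at a time; with the embeddings stacked into a matrix, a whole row can be computed at once. A vectorized sketch (the stacked matrix is an assumption; the snippet keeps embeddings in a dict):

import numpy as np

def distance_row(embedding_i, embedding_matrix):
    # squared Euclidean distance from one embedding to every row at once
    diff = embedding_matrix - embedding_i  # broadcasts over rows
    return np.einsum('ij,ij->i', diff, diff)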
Example #16
def train_recons(original_imgs_path,
                 save_path,
                 model_pre_path,
                 EPOCHES_set,
                 BATCH_SIZE,
                 debug=False,
                 logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()
    EPOCHS = EPOCHES_set
    print("EPOCHES   : ", EPOCHS)
    print("BATCH_SIZE: ", BATCH_SIZE)
    num_imgs = len(original_imgs_path)
    # num_imgs = 100
    original_imgs_path = original_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE

    print('Train images number %d.\n' % num_imgs)
    print('Train batches per epoch %s.\n' % str(num_imgs / BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        original_imgs_path = original_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE_OR,
                                  name='original')
        source = original

        print('source  :', source.shape)
        print('original:', original.shape)

        # create the deepfuse net (encoder and decoder)
        dfn = DeepFuseNet(model_pre_path)

        generated_img = dfn.transform_recons(source)

        print('generate:', generated_img.shape)

        ssim_loss = SSIM_LOSS(original, generated_img)

        loss = 1 - ssim_loss
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(original_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                '\nElapsed time for preprocessing before actually train the model: %s'
                % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        Loss_all = [i for i in range(EPOCHS * n_batches)]
        for epoch in range(EPOCHS):

            np.random.shuffle(original_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of content and style images

                original_path = original_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                original_batch = get_train_images(original_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH,
                                                  flag=False)
                original_batch = original_batch.reshape(
                    [BATCH_SIZE, 256, 256, 1])

                # print('original_batch shape final:', original_batch.shape)

                # run the training step
                sess.run(train_op, feed_dict={original: original_batch})
                step += 1
                # if step % 1000 == 0:
                #     saver.save(sess, save_path, global_step=step)
                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _ssim_loss, _loss = sess.run(
                            [ssim_loss, loss],
                            feed_dict={original: original_batch})
                        Loss_all[count_loss] = _loss
                        count_loss += 1
                        print(
                            'Deep fuse==>>step: %d,  total loss: %s,  elapsed time: %s'
                            % (step, _loss, elapsed_time))
                        print('ssim_loss: %s ' % (_ssim_loss))

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        loss_data = Loss_all[:count_loss]
        scio.savemat(
            'D:/project/GitHub/ImageFusion/Imagefusion_deepfuse/DeepFuseLossData.mat',
            {'loss': loss_data})

        # iter_index = [i for i in range(count_loss)]
        # plt.plot(iter_index, Loss_all[:count_loss])
        # plt.show()

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
Example #17
def train_recons(original_imgs_path,
                 save_path,
                 model_pre_path,
                 EPOCHES_set,
                 BATCH_SIZE_set,
                 debug=False,
                 logging_period=1):
    from datetime import datetime
    if debug:
        start_time = datetime.now()
    EPOCHS = EPOCHES_set
    BATCH_SIZE = BATCH_SIZE_set
    print("EPOCHES   : ", EPOCHS)
    print("BATCH_SIZE: ", BATCH_SIZE)

    num_imgs = len(original_imgs_path)
    mod = num_imgs % BATCH_SIZE

    print('Train images number {}.'.format(num_imgs))
    print('Train images samples {}.'.format(num_imgs // BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed {} samples...'.format(mod))
        original_imgs_path = original_imgs_path[:-mod]

    # get the training image shape
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
        with tf.compat.v1.name_scope('Input'):
            original = tf.compat.v1.placeholder(tf.float32,
                                                shape=INPUT_SHAPE,
                                                name='original')
            source = original

        print('source :', source.shape)
        print('original :', original.shape)

        # create the deepfuse net (encoder and decoder)
        dfn = DenseFuseNet(model_pre_path)
        generated_img = dfn.transform_recons(source)
        print('generate:', generated_img.shape)

        epsilon_1 = tf.reduce_mean(tf.square(generated_img - original))
        epsilon_2 = 1 - tf.reduce_mean(
            tf.image.ssim(generated_img, original, max_val=1.0))
        total_loss = epsilon_1 + 1000 * epsilon_2

        tf.compat.v1.summary.scalar('epsilon_1', epsilon_1)
        tf.compat.v1.summary.scalar('epsilon_2', epsilon_2)
        tf.compat.v1.summary.scalar('total_loss', total_loss)

        train_op = tf.compat.v1.train.AdamOptimizer(LEARNING_RATE).minimize(
            total_loss)

        summary_op = tf.compat.v1.summary.merge_all()
        train_writer = tf.compat.v1.summary.FileWriter('./models/log',
                                                       sess.graph,
                                                       flush_secs=60)
        train_writer.add_graph(sess.graph)

        sess.run(tf.compat.v1.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.compat.v1.train.Saver(max_to_keep=20)

        # ** Start Training **
        step = 0
        n_batches = int(len(original_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                'Elapsed time for preprocessing before actually train the model: {}'
                .format(elapsed_time))
            print('Now begin to train the model...')
            start_time = datetime.now()

        Loss_1 = []
        Loss_2 = []
        Loss_all = []
        for epoch in range(EPOCHS):
            for batch in range(n_batches):
                # retrieve a batch of infrared and visible images
                original_path = original_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                original_batch = get_train_images(original_path)
                # print(original_batch.shape)
                original_batch = original_batch.transpose((3, 0, 1, 2))
                # run the training step
                step += 1
                _, summary_str, _epsilon_1, _epsilon_2, _total_loss = sess.run(
                    [train_op, summary_op, epsilon_1, epsilon_2, total_loss],
                    feed_dict={original: original_batch})

                train_writer.add_summary(summary_str, step)
                Loss_1.append(_epsilon_1)
                Loss_2.append(_epsilon_2)
                Loss_all.append(_total_loss)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        print(
                            'epoch:{:>2}/{}, step:{:>4}, total loss: {:.4f}, elapsed time: {}'
                            .format(epoch + 1, EPOCHS, step, _total_loss,
                                    elapsed_time))
                        print('epsilon_1: {}, epsilon_2: {}\n'.format(
                            _epsilon_1, _epsilon_2))

            # ** Done Training & Save the model **
            saver.save(sess, save_path, global_step=epoch + 1)

            if not os.path.exists('./models/loss/'):
                os.mkdir('./models/loss/')

            scio.savemat('./models/loss/TotalLoss_' + str(epoch + 1) + '.mat',
                         {'total_loss': Loss_all})
            scio.savemat('./models/loss/Epsilon1_' + str(epoch + 1) + '.mat',
                         {'epsilon_1': Loss_1})
            scio.savemat('./models/loss/Epsilon2_' + str(epoch + 1) + '.mat',
                         {'epsilon_2': Loss_2})

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: {}'.format(elapsed_time))
            print('Model is saved to: {}'.format(save_path))
Example #18
def train_recons(original_imgs_path,
                 validatioin_imgs_path,
                 save_path,
                 model_pre_path,
                 ssim_weight,
                 EPOCHES_set,
                 BATCH_SIZE,
                 debug=False,
                 logging_period=1):
    if debug:
        from datetime import datetime
        start_time = datetime.now()
    EPOCHS = EPOCHES_set
    print("EPOCHES   : ", EPOCHS)
    print("BATCH_SIZE: ", BATCH_SIZE)

    num_val = len(validatioin_imgs_path)
    num_imgs = len(original_imgs_path)
    # num_imgs = 100
    original_imgs_path = original_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE

    print('Train images number %d.\n' % num_imgs)
    print('Train batches per epoch %s.\n' % str(num_imgs / BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        original_imgs_path = original_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE_OR,
                                  name='original')
        source = original

        print('source  :', source.shape)
        print('original:', original.shape)

        # create the deepfuse net (encoder and decoder)
        dfn = DenseFuseNet(model_pre_path)
        generated_img = dfn.transform_recons(source)
        print('generate:', generated_img.shape)

        ssim_loss_value = SSIM_LOSS(original, generated_img)
        pixel_loss = tf.reduce_sum(tf.square(original - generated_img))
        pixel_loss = pixel_loss / (BATCH_SIZE * HEIGHT * WIDTH)
        ssim_loss = 1 - ssim_loss_value

        loss = ssim_weight * ssim_loss + pixel_loss
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(original_imgs_path) // BATCH_SIZE)
        val_batches = int(len(validatioin_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                '\nElapsed time for preprocessing before actually train the model: %s'
                % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        Loss_all = [i for i in range(EPOCHS * n_batches)]
        Loss_ssim = [i for i in range(EPOCHS * n_batches)]
        Loss_pixel = [i for i in range(EPOCHS * n_batches)]
        Val_ssim_data = [i for i in range(EPOCHS * n_batches)]
        Val_pixel_data = [i for i in range(EPOCHS * n_batches)]
        for epoch in range(EPOCHS):

            np.random.shuffle(original_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of content and style images

                original_path = original_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                original_batch = get_train_images(original_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH,
                                                  flag=False)
                original_batch = original_batch.reshape(
                    [BATCH_SIZE, 256, 256, 1])

                # print('original_batch shape final:', original_batch.shape)

                # run the training step
                sess.run(train_op, feed_dict={original: original_batch})
                step += 1
                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _ssim_loss, _loss, _p_loss = sess.run(
                            [ssim_loss, loss, pixel_loss],
                            feed_dict={original: original_batch})
                        Loss_all[count_loss] = _loss
                        Loss_ssim[count_loss] = _ssim_loss
                        Loss_pixel[count_loss] = _p_loss
                        print(
                            'epoch: %d/%d, step: %d,  total loss: %s, elapsed time: %s'
                            % (epoch, EPOCHS, step, _loss, elapsed_time))
                        print('p_loss: %s, ssim_loss: %s ,w_ssim_loss: %s ' %
                              (_p_loss, _ssim_loss, ssim_weight * _ssim_loss))

                        # calculate the accuracy rate for 1000 images, every 100 steps
                        val_ssim_acc = 0
                        val_pixel_acc = 0
                        np.random.shuffle(validatioin_imgs_path)
                        val_start_time = datetime.now()
                        for v in range(val_batches):
                            val_original_path = validatioin_imgs_path[
                                v * BATCH_SIZE:(v * BATCH_SIZE + BATCH_SIZE)]
                            val_original_batch = get_train_images(
                                val_original_path,
                                crop_height=HEIGHT,
                                crop_width=WIDTH,
                                flag=False)
                            val_original_batch = val_original_batch.reshape(
                                [BATCH_SIZE, 256, 256, 1])
                            val_ssim, val_pixel = sess.run(
                                [ssim_loss, pixel_loss],
                                feed_dict={original: val_original_batch})
                            val_ssim_acc = val_ssim_acc + (1 - val_ssim)
                            val_pixel_acc = val_pixel_acc + val_pixel
                        Val_ssim_data[count_loss] = val_ssim_acc / val_batches
                        Val_pixel_data[
                            count_loss] = val_pixel_acc / val_batches
                        val_es_time = datetime.now() - val_start_time
                        print(
                            'validation value, SSIM: %s, Pixel: %s, elapsed time: %s'
                            % (val_ssim_acc / val_batches,
                               val_pixel_acc / val_batches, val_es_time))
                        print(
                            '------------------------------------------------------------------------------'
                        )
                        count_loss += 1

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        loss_data = Loss_all[:count_loss]
        scio.savemat(
            './models/loss/DeepDenseLossData' + str(ssim_weight) + '.mat',
            {'loss': loss_data})

        loss_ssim_data = Loss_ssim[:count_loss]
        scio.savemat(
            './models/loss/DeepDenseLossSSIMData' + str(ssim_weight) + '.mat',
            {'loss_ssim': loss_ssim_data})

        loss_pixel_data = Loss_pixel[:count_loss]
        scio.savemat(
            './models/loss/DeepDenseLossPixelData' + str(ssim_weight) + '.mat',
            {'loss_pixel': loss_pixel_data})

        validation_ssim_data = Val_ssim_data[:count_loss]
        scio.savemat(
            './models/val/Validation_ssim_Data' + str(ssim_weight) + '.mat',
            {'val_ssim': validation_ssim_data})

        validation_pixel_data = Val_pixel_data[:count_loss]
        scio.savemat(
            './models/val/Validation_pixel_Data' + str(ssim_weight) + '.mat',
            {'val_pixel': validation_pixel_data})

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
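The .mat logs saved above can be inspected offline; a minimal sketch, assuming the run used ssim_weight = 10 and the default './models/loss' location:

import scipy.io as scio
import matplotlib.pyplot as plt

ssim_weight = 10  # assumption: must match the value used during training

# savemat stores 1-D sequences as 2-D arrays, hence the flatten()
data = scio.loadmat('./models/loss/DeepDenseLossData' + str(ssim_weight) + '.mat')
loss = data['loss'].flatten()

plt.plot(loss)
plt.xlabel('logging step')
plt.ylabel('total loss')
plt.show()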
Exemplo n.º 19
0
def _handler_mix(ir_path,
                 vis_path,
                 model_path,
                 model_pre_path,
                 ssim_weight,
                 index,
                 output_path=None):
    mix_block = []
    ir_img = get_train_images(ir_path, flag=False)
    vis_img = get_train_images(vis_path, flag=False)
    dimension = ir_img.shape
    ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    ir_img = np.transpose(ir_img, (0, 2, 1, 3))
    vis_img = np.transpose(vis_img, (0, 2, 1, 3))

    print('img shape final:', ir_img.shape)
    with tf.Graph().as_default(), tf.Session() as sess:
        infrared_field = tf.placeholder(tf.float32,
                                        shape=ir_img.shape,
                                        name='content')
        visible_field = tf.placeholder(tf.float32,
                                       shape=vis_img.shape,
                                       name='style')

        # -----------------------------------------------

        dfn = DenseFuseNet(model_pre_path)

        #sess.run(tf.global_variables_initializer())

        enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2 = dfn.transform_encoder(
            infrared_field)
        enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2 = dfn.transform_encoder(
            visible_field)

        result = tf.placeholder(tf.float32, shape=enc_ir.shape, name='target')

        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        enc_ir_temp, enc_ir_res_block_temp, enc_ir_block_temp, enc_ir_block2_temp = sess.run(
            [enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2],
            feed_dict={infrared_field: ir_img})
        enc_vis_temp, enc_vis_res_block_temp, enc_vis_block_temp, enc_vis_block2_temp = sess.run(
            [enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2],
            feed_dict={visible_field: vis_img})

        block = L1_norm(enc_ir_block_temp, enc_vis_block_temp)
        block2 = L1_norm(enc_ir_block2_temp, enc_vis_block2_temp)

        first_first = L1_norm(enc_ir_res_block_temp[0],
                              enc_vis_res_block_temp[0])
        first_second = Strategy(enc_ir_res_block_temp[1],
                                enc_vis_res_block_temp[1])
        #first_third = L1_norm_attention(enc_ir_res_block_temp[2],feation_ir, enc_vis_res_block_temp[2],feation_vis)
        #first_four = L1_norm_attention(enc_ir_res_block_temp[3],feation_ir, enc_vis_res_block_temp[3],feation_vis)
        first_third = L1_norm(enc_ir_res_block_temp[2],
                              enc_vis_res_block_temp[2])
        first_four = Strategy(enc_ir_res_block_temp[3],
                              enc_vis_res_block_temp[3])
        # tf.to_int32 truncates the fused features to integers; cast back to
        # float32 so tf.concat receives a single dtype (concatenating int32
        # with float32 inputs would fail)
        first_first = tf.concat(
            [first_first, tf.to_float(tf.to_int32(first_second, name='ToInt'))], 3)
        first_first = tf.concat(
            [first_first, tf.to_float(tf.to_int32(first_third, name='ToInt'))], 3)
        first_first = tf.concat([first_first, first_four], 3)

        first = first_first

        second = L1_norm(enc_ir_res_block_temp[6], enc_vis_res_block_temp[6])
        third = L1_norm(enc_ir_res_block_temp[9], enc_vis_res_block_temp[9])

        feature = 1 * first + 0.1 * second + 0.1 * third

        #---------------------------------------------------------
        # block=Strategy(enc_ir_block_temp,enc_vis_block_temp)
        # block2=L1_norm(enc_ir_block2_temp,enc_vis_block2_temp)
        #---------------------------------------------------------

        feature = feature.eval()

        output_image = dfn.transform_decoder(result, block, block2)

        # output = dfn.transform_decoder(feature)
        # print(type(feature))
        # output = sess.run(output_image, feed_dict={result: feature,enc_res_block:block,enc_res_block2:block2})
        output = sess.run(output_image, feed_dict={result: feature})

        save_images(ir_path,
                    output,
                    output_path,
                    prefix='fused' + str(index),
                    suffix='_mix_' + str(ssim_weight))
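A minimal driver sketch for _handler_mix; every path and the weight below are hypothetical stand-ins for the real dataset and checkpoint:

SSIM_WEIGHT = 10                                   # assumption: weight used at training time
MODEL_PATH = './models/densefuse_model.ckpt'       # hypothetical checkpoint path
MODEL_PRE_PATH = None

for index in range(1, 11):                         # fuse ten image pairs
    ir_path = './images/IR%d.png' % index          # hypothetical file layout
    vis_path = './images/VIS%d.png' % index
    _handler_mix(ir_path, vis_path, MODEL_PATH, MODEL_PRE_PATH,
                 SSIM_WEIGHT, index, output_path='./outputs/')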
Exemplo n.º 20
0
def _handler_mix_a(ir_path,
                   vis_path,
                   model_path,
                   model_pre_path,
                   model_path_a,
                   model_pre_path_a,
                   ssim_weight,
                   index,
                   output_path=None):
    ir_img = get_train_images(ir_path, flag=False)
    vis_img = get_train_images(vis_path, flag=False)
    dimension = ir_img.shape
    ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]])
    ir_img = np.transpose(ir_img, (0, 2, 1, 3))
    vis_img = np.transpose(vis_img, (0, 2, 1, 3))

    g2 = tf.Graph()  # graph loaded into Session 2

    sess2 = tf.Session(graph=g2)  # Session 2

    with sess2.as_default():  # 1
        with g2.as_default(), tf.Session() as sess:
            infrared_field = tf.placeholder(tf.float32,
                                            shape=ir_img.shape,
                                            name='content')
            visible_field = tf.placeholder(tf.float32,
                                           shape=vis_img.shape,
                                           name='style')

            dfn = DenseFuseNet(model_pre_path)

            # sess.run(tf.global_variables_initializer())

            enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2 = dfn.transform_encoder(
                infrared_field)
            enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2 = dfn.transform_encoder(
                visible_field)

            result = tf.placeholder(tf.float32,
                                    shape=enc_ir.shape,
                                    name='target')

            saver = tf.train.Saver()
            saver.restore(sess, model_path)

            # ------------------------attention------------------------------------------------------
            #feature_a,feature_b=_get_attention(ir_path,vis_path,model_path_a,model_pre_path_a)
            #print("______+++________")
            #print(feature_a[0].shape)
            # ------------------------attention------------------------------------------------------

            enc_ir_temp, enc_ir_res_block_temp, enc_ir_block_temp, enc_ir_block2_temp = sess.run(
                [enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2],
                feed_dict={infrared_field: ir_img})
            enc_vis_temp, enc_vis_res_block_temp, enc_vis_block_temp, enc_vis_block2_temp = sess.run(
                [enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2],
                feed_dict={visible_field: vis_img})

            # ------------------------------------------------------------------------------------------------------------
            #------------------------------------------------------------------------------------------------------------
            block = 0.8 * enc_vis_block_temp + 0.2 * enc_ir_block_temp
            block2 = 0.4 * enc_ir_block2_temp + 0.6 * enc_vis_block2_temp

            #first_first = Strategy(enc_ir_res_block_temp[0], enc_vis_res_block_temp[0])
            #first_first = L1_norm(enc_ir_res_block_temp[0], enc_vis_res_block_temp[0])
            #first_second = Strategy(enc_ir_res_block_temp[1], enc_vis_res_block_temp[1])
            #first_second = L1_norm(enc_ir_res_block_temp[1], enc_vis_res_block_temp[1])
            #first_third = Strategy(enc_ir_res_block_temp[2], enc_vis_res_block_temp[2])
            #first_third = L1_norm_attention(enc_ir_res_block_temp[2],feature_a, enc_vis_res_block_temp[2],feature_b)
            #first_four = Strategy(enc_ir_res_block_temp[3], enc_vis_res_block_temp[3])
            #first_four = L1_norm_attention(enc_ir_res_block_temp[3],feature_a, enc_vis_res_block_temp[3],feature_b)
            #first_first = tf.concat([first_first, tf.to_int32(first_second, name='ToInt')], 3)
            #first_first = tf.concat([first_first, tf.to_int32(first_third, name='ToInt')], 3)
            #first_first = tf.concat([first_first, first_four], 3)

            #first = first_first

            first = Strategy(enc_ir_res_block_temp[3],
                             enc_vis_res_block_temp[3])
            second = Strategy(enc_ir_res_block_temp[6],
                              enc_vis_res_block_temp[6])
            third = Strategy(enc_ir_res_block_temp[9],
                             enc_vis_res_block_temp[9])
            # ------------------------------------------------------------------------------------------------------------
            # ------------------------------------------------------------------------------------------------------------

            feature = 1 * first + 1 * second + 1 * third

            # ---------------------------------------------------------
            # block=Strategy(enc_ir_block_temp,enc_vis_block_temp)
            # block2=L1_norm(enc_ir_block2_temp,enc_vis_block2_temp)
            # ---------------------------------------------------------

            #feature = feature.eval()

            # -------------- collapse the feature maps into a single channel ----------------
            #feature_map_vis_out = sess.run(tf.reduce_sum(feature_a[0], 3, keep_dims=True))
            #feature_map_ir_out = sess.run(tf.reduce_sum(feature_b[0],3, keep_dims=True))
            # ------------------------------------------------------------------

            output_image = dfn.transform_decoder(result, block, block2)

            # output = dfn.transform_decoder(feature)
            # print(type(feature))
            # output = sess.run(output_image, feed_dict={result: feature,enc_res_block:block,enc_res_block2:block2})
            output = sess.run(output_image, feed_dict={result: feature})

            save_images(ir_path,
                        output,
                        output_path,
                        prefix='fused' + str(index),
                        suffix='_mix_' + str(ssim_weight))
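The block/block2 rule above is just a fixed convex combination of the two encoders' feature maps; the same arithmetic as a standalone numpy sketch (shapes are illustrative only):

import numpy as np

def weighted_fuse(feat_ir, feat_vis, w_ir):
    # convex combination of two equally shaped feature maps
    assert feat_ir.shape == feat_vis.shape
    return w_ir * feat_ir + (1.0 - w_ir) * feat_vis

# mirrors block (0.2*IR + 0.8*VIS) and block2 (0.4*IR + 0.6*VIS) above
ir = np.random.rand(1, 64, 64, 16).astype(np.float32)
vis = np.random.rand(1, 64, 64, 16).astype(np.float32)
block = weighted_fuse(ir, vis, 0.2)
block2 = weighted_fuse(ir, vis, 0.4)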
Exemplo n.º 21
0
        print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)

        print('Now begin to train the model...\n')
        start_time = datetime.now()

        for epoch in range(EPOCHS):

            np.random.shuffle(content_images)
            np.random.shuffle(style_images)

            for batch in range(n_batches):
                # retrieve a batch of content and style images
                content_batch_path = content_images[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                style_batch_path = style_images[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]

                content_batch = utils.get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
                style_batch = utils.get_train_images(style_batch_path, crop_height=HEIGHT, crop_width=WIDTH)

                # run the training step
                sess.run(train_op, feed_dict={content_input: content_batch, style_input: style_batch})

                if step % 100 == 0:

                    _content_loss, _style_loss, _loss = sess.run([content_loss, style_loss, loss],
                                                                 feed_dict={content_input: content_batch,
                                                                            style_input: style_batch})

                    elapsed_time = datetime.now() - start_time
                    print('step: %d,  total loss: %.3f, elapsed time: %s' % (step, _loss, elapsed_time))
                    print('content loss: %.3f' % (_content_loss))
                    print('style loss  : %.3f,  weighted style loss: %.3f\n' % (
                        _style_loss, style_weight * _style_loss))  # assumed completion: mirrors the loss prints above; style_weight comes from the elided scope
Exemplo n.º 22
0
def train_recons(original_imgs_path, validatioin_imgs_path, save_path, model_pre_path, ssim_weight, EPOCHES_set, BATCH_SIZE, debug=False, logging_period=1):
    if debug:
        from datetime import datetime
        start_time = datetime.now()
    EPOCHS = EPOCHES_set
    print("EPOCHES   : ", EPOCHS)             #EPOCHS = 4           遍历整个数据集的次数,训练网络一共要执行n*4次
    print("BATCH_SIZE: ", BATCH_SIZE)         #BATCH_SIZE = 2       每个Batch有2个样本,共n/2个Batch,每处理两个样本模型权重就更新

    num_val = len(validatioin_imgs_path)        # number of validation samples
    num_imgs = len(original_imgs_path)          # number of training samples
    # num_imgs = 100
    original_imgs_path = original_imgs_path[:num_imgs]                          # no-op: the full list is assigned back to itself
    mod = num_imgs % BATCH_SIZE                 # remainder after splitting into batches

    print('Train images number %d.\n' % num_imgs)
    print('Train batches per epoch %s.\n' % str(num_imgs // BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        original_imgs_path = original_imgs_path[:-mod]                          # drop the trailing images so the set divides evenly into batches

    # get the training image shape:
    # height, width and channel count of the crops (256, 256, 1)
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)                         # batched input shape

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)             # the _OR suffix presumably means "original" (uncropped) shape

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='original')
        # tf.placeholder reserves a node in the graph that is only given data
        # at run time through feed_dict; its arguments are the dtype (usually
        # tf.float32 or tf.float64), the shape (batch, height, width,
        # channels) and a name, and it returns a Tensor
        source = original                                               # alias: source and original refer to the same tensor

        print('source  :', source.shape)
        print('original:', original.shape)

        # create the deepfuse net (encoder and decoder)
        dfn = DenseFuseNet(model_pre_path)                              # model_pre_path is a user-supplied checkpoint; it defaults to None, and when set, training starts from those weights
        generated_img = dfn.transform_recons(source)                    # reconstructed output image
        print('generate:', generated_img.shape)

        #########################################################################################
        # COST FUNCTION
        ssim_loss_value = SSIM_LOSS(original, generated_img)                # structural similarity between input and reconstruction
        pixel_loss = tf.reduce_sum(tf.square(original - generated_img))
        pixel_loss = pixel_loss/(BATCH_SIZE*HEIGHT*WIDTH)                   # mean squared pixel error
        ssim_loss = 1 - ssim_loss_value                                     # SSIM loss term

        loss = ssim_weight*ssim_loss + pixel_loss                           # total loss
        #train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)    # Adam (adaptive moment estimation, a gradient-descent variant)
        train_op = tf.train.AdamOptimizer(LEARNING_RATE_2).minimize(loss)   # Adam (adaptive moment estimation, a gradient-descent variant)
        ##########################################################################################

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(original_imgs_path) // BATCH_SIZE)
        val_batches = int(len(validatioin_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        # preallocate logging buffers, one slot per potential logging step
        Loss_all = np.zeros(EPOCHS * n_batches)
        Loss_ssim = np.zeros(EPOCHS * n_batches)
        Loss_pixel = np.zeros(EPOCHS * n_batches)
        Val_ssim_data = np.zeros(EPOCHS * n_batches)
        Val_pixel_data = np.zeros(EPOCHS * n_batches)
        for epoch in range(EPOCHS):

            np.random.shuffle(original_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of training images

                original_path = original_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]
                original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                original_batch = original_batch.reshape([BATCH_SIZE, 256, 256, 1])

                # print('original_batch shape final:', original_batch.shape)

                # run the training step
                sess.run(train_op, feed_dict={original: original_batch})
                step += 1
                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _ssim_loss, _loss, _p_loss = sess.run([ssim_loss, loss, pixel_loss], feed_dict={original: original_batch})
                        Loss_all[count_loss] = _loss
                        Loss_ssim[count_loss] = _ssim_loss
                        Loss_pixel[count_loss] = _p_loss
                        print('epoch: %d/%d, step: %d,  total loss: %s, elapsed time: %s' % (epoch, EPOCHS, step, _loss, elapsed_time))
                        print('p_loss: %s, ssim_loss: %s ,w_ssim_loss: %s ' % (_p_loss, _ssim_loss, ssim_weight * _ssim_loss))

                        # evaluate SSIM and pixel loss over the whole validation set at each logging step
                        val_ssim_acc = 0
                        val_pixel_acc = 0
                        np.random.shuffle(validatioin_imgs_path)
                        val_start_time = datetime.now()
                        for v in range(val_batches):
                            val_original_path = validatioin_imgs_path[v * BATCH_SIZE:(v * BATCH_SIZE + BATCH_SIZE)]
                            val_original_batch = get_train_images(val_original_path, crop_height=HEIGHT, crop_width=WIDTH,flag=False)
                            val_original_batch = val_original_batch.reshape([BATCH_SIZE, 256, 256, 1])
                            val_ssim, val_pixel = sess.run([ssim_loss, pixel_loss], feed_dict={original: val_original_batch})
                            val_ssim_acc = val_ssim_acc + (1 - val_ssim)
                            val_pixel_acc = val_pixel_acc + val_pixel
                        Val_ssim_data[count_loss] = val_ssim_acc/val_batches
                        Val_pixel_data[count_loss] = val_pixel_acc / val_batches
                        val_es_time = datetime.now() - val_start_time
                        print('validation value, SSIM: %s, Pixel: %s, elapsed time: %s' % (val_ssim_acc/val_batches, val_pixel_acc / val_batches, val_es_time))
                        print('------------------------------------------------------------------------------')
                        count_loss += 1


        # ** Done Training & Save the model **
        saver.save(sess, save_path)
#----------------------------------------------------------------------------------------------------------------
        loss_data = Loss_all[:count_loss]
        scio.savemat('/data/ljy/1-Project-Go/01-06-upsampling/models/loss/DeepDenseLossData' + str(ssim_weight) + '.mat',
                     {'loss': loss_data})

        loss_ssim_data = Loss_ssim[:count_loss]
        scio.savemat('/data/ljy/1-Project-Go/01-06-upsampling/models/loss/DeepDenseLossSSIMData' + str(
            ssim_weight) + '.mat', {'loss_ssim': loss_ssim_data})

        loss_pixel_data = Loss_pixel[:count_loss]
        scio.savemat('/data/ljy/1-Project-Go/01-06-upsampling/models/loss/DeepDenseLossPixelData' + str(
            ssim_weight) + '.mat', {'loss_pixel': loss_pixel_data})

        validation_ssim_data = Val_ssim_data[:count_loss]
        scio.savemat('/data/ljy/1-Project-Go/01-06-upsampling/models/val/Validation_ssim_Data' + str(
            ssim_weight) + '.mat', {'val_ssim': validation_ssim_data})

        validation_pixel_data = Val_pixel_data[:count_loss]
        scio.savemat('/data/ljy/1-Project-Go/01-06-upsampling/models/val/Validation_pixel_Data' + str(
            ssim_weight) + '.mat', {'val_pixel': validation_pixel_data})
#----------------------------------------------------------------------------------------------------
        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
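A minimal call sketch for train_recons, with hypothetical image directories, reusing the utils.list_images helper seen in the next example; the epoch and batch values echo the comments above:

original_imgs = utils.list_images('./MSCOCO/train/')    # hypothetical directory
validation_imgs = utils.list_images('./MSCOCO/val/')    # hypothetical directory

train_recons(original_imgs, validation_imgs,
             save_path='./models/densefuse.ckpt',       # hypothetical path
             model_pre_path=None,
             ssim_weight=10, EPOCHES_set=4, BATCH_SIZE=2,
             debug=True, logging_period=100)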
Exemplo n.º 23
0
def train(opt):
    '''
    Input:
        opt : options for training the model
    '''
    content_img_list = utils.list_images(opt.content_img_dir)
    style_img_list = utils.list_images(opt.style_img_dir)

    #import pdb; pdb.set_trace()

    assert(content_img_list) # ensure not empty 
    assert(style_img_list)
    
    with tf.Graph().as_default(), tf.Session() as sess:
        content_img = tf.placeholder(tf.float32, shape=(opt.batch_size, opt.img_size, opt.img_size, 3), name='content_img')
        style_img   = tf.placeholder(tf.float32, shape=(opt.batch_size, opt.img_size, opt.img_size, 3), name='style_img')

        model = Model(opt.checkpoint_encoder)
        
        # Encode Image 
        generated_img = model.transform(content_img, style_img)
        generated_img = tf.reverse(generated_img, axis=[-1])  
        generated_img = model.encoder.preprocess(generated_img)  
        enc_gen, enc_gen_layers = model.encoder.encode(generated_img)

        target_features = model.target_features

        # Content Loss 
        content_loss = tf.reduce_sum(tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))

        # Style Loss 
        style_layer_loss = []
        style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1']
        for layer in style_layers:
            enc_style_feat = model.encoded_style_layers[layer]
            enc_gen_feat   = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat,   [1, 2])

            sigmaS = tf.sqrt(varS + opt.epsilon)
            sigmaG = tf.sqrt(varG + opt.epsilon)

            l2_mean  = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # Total loss 
        loss = opt.content_weight * content_loss + opt.style_weight * style_loss

        # Train 
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.inverse_time_decay(opt.lr, global_step, opt.lr_decay_step, opt.lr_decay)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

        sess.run(tf.global_variables_initializer())

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        step = 0
        n_batches = int(len(content_img_list) // opt.batch_size)

        try:
            for epoch in range(opt.epoch):

                np.random.shuffle(content_img_list)
                np.random.shuffle(style_img_list)

                for batch in range(n_batches):
                    print('iteration {}'.format(step))

                    # retrieve a batch of content and style images
                    content_batch_path = content_img_list[batch*opt.batch_size:(batch*opt.batch_size + opt.batch_size)]
                    style_batch_path   = style_img_list[batch*opt.batch_size:(batch*opt.batch_size + opt.batch_size)]

                    content_batch = utils.get_train_images(content_batch_path, crop_height=opt.img_size, crop_width=opt.img_size)
                    style_batch   = utils.get_train_images(style_batch_path,   crop_height=opt.img_size, crop_width=opt.img_size)

                    # run the training step
                    sess.run(train_op, feed_dict={
                        content_img: content_batch, 
                        style_img: style_batch
                        })

                    step += 1

                    if step % 1000 == 0:
                        saver.save(sess, opt.checkpoint_save_dir, global_step=step, write_meta_graph=False)

        except Exception as ex:
            saver.save(sess, opt.checkpoint_save_dir, global_step=step)
            print('Error message: %s' % str(ex))

        # Finish 
        saver.save(sess, opt.checkpoint_save_dir)
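The per-layer style term above matches the channel-wise mean and standard deviation of the encoded features over the spatial axes; an equivalent numpy sketch for a single layer:

import numpy as np

def style_layer_loss(enc_style, enc_gen, epsilon=1e-5):
    # first/second moment matching over the spatial axes (H, W),
    # mirroring tf.nn.moments(..., [1, 2]) in the graph above
    meanS = enc_style.mean(axis=(1, 2))
    meanG = enc_gen.mean(axis=(1, 2))
    sigmaS = np.sqrt(enc_style.var(axis=(1, 2)) + epsilon)
    sigmaG = np.sqrt(enc_gen.var(axis=(1, 2)) + epsilon)
    return np.sum((meanG - meanS) ** 2) + np.sum((sigmaG - sigmaS) ** 2)

# illustrative feature maps shaped (batch, H, W, C)
fs = np.random.rand(2, 32, 32, 64).astype(np.float32)
fg = np.random.rand(2, 32, 32, 64).astype(np.float32)
print(style_layer_loss(fs, fg))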