Example #1
0
def TestNetwork():
	"""Run the trained depth-prediction + Q-learning policy on the real robot.

	Loads the FCRN depth network and the QNetwork from checkpoints, then
	loops until ROS shutdown: predicts depth from the RGB camera, keeps a
	rolling history of IMAGE_HIST depth frames, and issues the greedy
	(argmax-Q) action at 3 Hz.

	Relies on module-level globals: batch_size, dtype, IMAGE_HIST, FCRN,
	QNetwork, RealWorld, and the torch / numpy / cv2 / rospy imports.
	"""
	# define and load depth_prediction network
	depth_net = FCRN(batch_size)
	resume_file_depth = 'checkpoint.pth.tar'
	checkpoint = torch.load(resume_file_depth)
	depth_net.load_state_dict(checkpoint['state_dict'])
	depth_net.cuda()

	# define and load q_learning network
	online_net = QNetwork()
	resume_file_online = 'online_with_noise.pth.tar'
	checkpoint_online = torch.load(resume_file_online)
	online_net.load_state_dict(checkpoint_online['state_dict'])
	online_net.cuda()
	rospy.sleep(1.)  # let ROS settle before reading sensors

	# Initialize the World and variables
	env = RealWorld()
	print('Environment initialized')
	episode = 0

	# start training
	rate = rospy.Rate(3)  # control-loop frequency: 3 Hz

	with torch.no_grad():  # pure inference -- no gradients needed
		while not rospy.is_shutdown():
			episode += 1
			t = 0

			# Bootstrap the depth history from the first frame:
			# HWC image -> NCHW float batch, predict depth, then replicate
			# the frame 4 times (presumably IMAGE_HIST == 4 -- TODO confirm).
			rgb_img_t1 = env.GetRGBImageObservation()
			rgb_img_t1 = rgb_img_t1[np.newaxis, :]
			rgb_img_t1 = torch.from_numpy(rgb_img_t1)
			rgb_img_t1 = rgb_img_t1.permute(0, 3, 1, 2)
			rgb_img_t1 = Variable(rgb_img_t1.type(dtype))
			depth_img_t1 = depth_net(rgb_img_t1) - 0.2  # constant bias offset -- NOTE(review): confirm calibration
			depth_img_t1 = torch.squeeze(depth_img_t1, 1)
			depth_imgs_t1 = torch.stack((depth_img_t1, depth_img_t1, depth_img_t1, depth_img_t1), dim=1)

			while not rospy.is_shutdown():
				# Grab a frame and predict depth for the newest time step.
				rgb_img_t1 = env.GetRGBImageObservation()
				# cv2.imwrite('rgb_depth.png', rgb_img_t1)
				rgb_img_t1 = rgb_img_t1[np.newaxis, :]
				rgb_img_t1 = torch.from_numpy(rgb_img_t1)
				rgb_img_t1 = rgb_img_t1.permute(0, 3, 1, 2)
				rgb_img_t1 = Variable(rgb_img_t1.type(dtype))
				depth_img_t1 = depth_net(rgb_img_t1) - 0.2
				# Slide the history window: newest frame first, oldest dropped.
				depth_imgs_t1 = torch.cat((depth_img_t1, depth_imgs_t1[:, :(IMAGE_HIST - 1), :, :]), 1)
				depth_imgs_t1_cuda = Variable(depth_imgs_t1.type(dtype))
				predicted_depth = depth_img_t1[0].data.squeeze().cpu().numpy().astype(np.float32)
				cv2.imwrite('predicted_depth.png', predicted_depth * 50)  # x50: scale for visualization


				# Greedy policy: take the action with the highest Q-value.
				Q_value_list = online_net(depth_imgs_t1_cuda)
				Q_value_list = Q_value_list[0]
				Q_value, action = torch.max(Q_value_list, 0)
				env.Control(action)
				t += 1
				rate.sleep()
Example #2
0
# Evaluation-script configuration (module level).
data_path = 'test.mat'  # .mat file holding the test data
dtype = torch.cuda.FloatTensor  # every tensor is cast to GPU float32

batch_size = 1
resume_from_file = True
# Accumulators for standard depth-estimation metrics, judging by their
# names: delta < 1.25 / 1.25^2 / 1.25^3 accuracy counters, RMSE variants,
# and absolute / squared relative difference (updated elsewhere).
Threshold_1_25 = 0
Threshold_1_25_2 = 0
Threshold_1_25_3 = 0
RMSE_linear = 0.0
RMSE_log = 0.0
RMSE_log_scale_invariant = 0.0
ARD = 0.0
SRD = 0.0

# Model under evaluation; weights are restored from the checkpoint below.
model = FCRN(batch_size)
model = model.cuda()
loss_fn = torch.nn.MSELoss().cuda()

resume_file = 'checkpoint.pth.tar'

# Restore model weights (and record the epoch they were saved at).
if resume_from_file:
    if os.path.isfile(resume_file):
        print("=> loading checkpoint '{}'".format(resume_file))
        checkpoint = torch.load(resume_file)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_file, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resume_file))
def main():
    """Train FCRN on NYU Depth v2, validating and checkpointing each epoch.

    Hyper-parameters come from the module-level ``args`` namespace; the loss
    is selected by ``args.loss_type`` (berhu / L1 / mse / ssim / three).
    Sample RGB / ground-truth / predicted images are written to ./result on
    every validation pass, a model snapshot is taken every 10th epoch, and
    the best model (by validation error) is saved to checkpoint.pth.tar.
    """
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr  # e.g. 1.0e-4 or 1.0e-5
    monentum = 0.9  # NOTE(review): only used by the commented-out SGD variants
    weight_decay = 0.0005  # NOTE(review): only used by the commented-out SGD variants
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False
    isDataAug = args.data_aug
    max_depth = 1000  # numerator of the inverse-depth augmentation below

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, train_lists),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, val_lists),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, test_lists),
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=True)
    print(train_loader)

    # 2.set the model
    print("Set the model......")
    model = FCRN(batch_size)
    # NOTE(review): an unused torchvision ResNet-50 used to be constructed
    # here; its pretrained-weight loading was already commented out, so the
    # dead construction has been removed. To warm-start the encoder:
    #   resnet = torchvision.models.resnet50()
    #   resnet.load_state_dict(torch.load('.../resnet50-19c8e357.pth'))
    # model.load_state_dict(load_weights(model, weights_file, dtype))
    model = model.cuda()

    # 3.Loss, selected from the command line.
    if args.loss_type == "berhu":
        loss_fn = criteria.berHuLoss().cuda()
        print("berhu loss_fn set.")
    elif args.loss_type == "L1":
        loss_fn = criteria.MaskedL1Loss().cuda()
        print("L1 loss_fn set.")
    elif args.loss_type == "mse":
        loss_fn = criteria.MaskedMSELoss().cuda()
        print("MSE loss_fn set.")
    elif args.loss_type == "ssim":
        loss_fn = criteria.SsimLoss().cuda()
        print("Ssim loss_fn set.")
    elif args.loss_type == "three":
        loss_fn = criteria.Ssim_grad_L1().cuda()
        print("SSIM+L1+Grad loss_fn set.")
    else:
        # BUGFIX: an unknown loss_type previously left loss_fn undefined and
        # crashed later with a confusing NameError.
        raise ValueError("unknown loss_type: {!r}".format(args.loss_type))

    # 5.Train
    best_val_err = 1.0e3
    os.makedirs('./result', exist_ok=True)  # imsave below fails if the dir is missing

    # Baseline validation pass before any training.
    model.eval()
    num_samples = 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(
                1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(
                np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(
                np.float32)

            # Normalize to [0, 1] for visualization.
            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('./result/input_rgb_epoch_0.png', input_rgb_image)
            plot.imsave('./result/gt_depth_epoch_0.png',
                        input_gt_depth_image,
                        cmap="viridis")
            # BUGFIX: the epoch-0 prediction used to land in the CWD while
            # its RGB / GT siblings went to ./result -- keep them together.
            plot.imsave('./result/pred_depth_epoch_0.png',
                        pred_depth_image,
                        cmap="viridis")

            loss_local += loss_fn(output, depth_var)

            num_samples += 1

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

    start_epoch = 0

    # Optionally resume weights and the epoch counter.
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 4.Optim: Adam with a StepLR schedule.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=step_size,
                       gamma=step_gamma)  # may change to other value
    print("optimizer set.")

    for epoch in range(num_epochs):
        print('Starting train epoch %d / %d' %
              (start_epoch + epoch + 1, num_epochs))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0

        for input, depth in train_loader:
            print("depth", depth)
            if isDataAug:
                # Inverse-depth augmentation: clamp to [10, 1000], then invert.
                depth = depth * 1000
                depth = torch.clamp(depth, 10, 1000)
                depth = max_depth / depth

            input_var = Variable(
                input.type(dtype))  # variable is for derivative
            depth_var = Variable(
                depth.type(dtype))  # variable is for derivative

            output = model(input_var)

            loss = loss_fn(output, depth_var)
            print('loss:', loss.data.cpu())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # Validation pass (same augmentation as training, if enabled).
        model.eval()
        num_samples = 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:

                if isDataAug:
                    depth = depth * 1000
                    depth = torch.clamp(depth, 10, 1000)
                    depth = max_depth / depth

                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)

                input_rgb_image = input_var[0].data.permute(
                    1, 2, 0).cpu().numpy().astype(np.uint8)
                input_gt_depth_image = depth_var[0][0].data.cpu().numpy(
                ).astype(np.float32)
                pred_depth_image = output[0].data.squeeze().cpu().numpy(
                ).astype(np.float32)

                # normalization
                input_gt_depth_image /= np.max(input_gt_depth_image)
                pred_depth_image /= np.max(pred_depth_image)

                plot.imsave(
                    './result/input_rgb_epoch_{}.png'.format(start_epoch +
                                                             epoch + 1),
                    input_rgb_image)
                plot.imsave(
                    './result/gt_depth_epoch_{}.png'.format(start_epoch +
                                                            epoch + 1),
                    input_gt_depth_image,
                    cmap="viridis")
                plot.imsave(
                    './result/pred_depth_epoch_{}.png'.format(start_epoch +
                                                              epoch + 1),
                    pred_depth_image,
                    cmap="viridis")

                loss_local += loss_fn(output, depth_var)

                num_samples += 1

        # BUGFIX: this periodic snapshot used to sit inside the validation
        # loop, rewriting the same file once per batch; once per epoch has
        # the identical effect.
        if epoch % 10 == 9:
            PATH = args.loss_type + '.pth'
            torch.save(model.state_dict(), PATH)

        err = float(loss_local) / num_samples
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': start_epoch + epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        scheduler.step()
Example #4
0
import torch  # BUGFIX: used below (torch.load, dtype) but never imported in this chunk
import torch.nn as nn
import random
import numpy as np
import time
import rospy
from fcrn import FCRN  # depth-prediction network; this file could be changed
import cv2
from torch.autograd import Variable

batch_size = 1
dtype = torch.cuda.FloatTensor  # all tensors on GPU

# Load the trained depth-prediction network.
depth_net = FCRN(batch_size)
resume_file_depth = 'checkpoint.pth.tar'
checkpoint = torch.load(resume_file_depth)
depth_net.load_state_dict(checkpoint['state_dict'])
depth_net.cuda()

# Read a test image and convert it to the network's input layout:
# resize to 640x480, BGR -> RGB, then NHWC uint8 -> NCHW float on the GPU.
img = cv2.imread("./test_images/599.png")
print(img.shape)  # BUGFIX: was Python-2 `print img.shape` -- a SyntaxError on Python 3
img = cv2.resize(img, (640, 480), interpolation=cv2.INTER_NEAREST)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img[np.newaxis, :])
img = img.permute(0, 3, 1, 2)
img = Variable(img.type(dtype))
Example #5
0
def main():
    """Train FCRN on NYU Depth v2 with the paper's reverse-Huber (berHu) loss.

    Starts from the official weights (converted from TensorFlow), trains for
    ``num_epochs`` epochs with Adam, saves sample images on the final epoch
    only, and checkpoints whenever validation error improves (and always on
    the last epoch). The learning rate is decayed by 0.8 every 10 epochs.
    """
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    monentum = 0.9  # NOTE(review): unused -- Adam is the optimizer below
    weight_decay = 0.0005  # NOTE(review): unused below
    num_epochs = 100

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data...")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, train_lists),
                                               batch_size=batch_size,
                                               shuffle=False,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, val_lists),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, test_lists),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)
    print(train_loader)

    # 2.Load model
    print("Loading model...")
    model = FCRN(batch_size)
    # Load the official weights (converted from TensorFlow).
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # Optionally resume a previously trained model.
    # BUGFIX: start_epoch used to be reset to 0 *after* this block, which
    # silently discarded a resumed epoch counter.
    start_epoch = 0
    resume_from_file = False
    resume_file = './model/model_300.pth'
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("can not find!")
    model = model.cuda()

    # 3.Loss: the paper's reverse Huber (berHu).
    # Alternatives: torch.nn.MSELoss() or the custom loss_mse().
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4.Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set...")

    # 5.Train
    # NOTE(review): this threshold is so low that checkpoints are effectively
    # written only by the final-epoch guard below -- confirm it was not meant
    # to be 1.0e3 as in the sibling training scripts.
    best_val_err = 1.0e-4

    for epoch in range(num_epochs):
        print('Starting train epoch %d / %d' %
              (start_epoch + epoch + 1, num_epochs + start_epoch))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0
        for input, depth in train_loader:

            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss: %f' % loss.data.cpu().item())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_samples = 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)
                if num_epochs == epoch + 1:
                    # Save sample images on the final epoch only (see the
                    # loader for how the saved test images are laid out).
                    input_rgb_image = input[0].data.permute(1, 2, 0)
                    input_gt_depth_image = depth_var[0][0].data.cpu().numpy(
                    ).astype(np.float32)
                    pred_depth_image = output[0].data.squeeze().cpu().numpy(
                    ).astype(np.float32)

                    # Normalize to [0, 1] for visualization.
                    input_gt_depth_image /= np.max(input_gt_depth_image)
                    pred_depth_image /= np.max(pred_depth_image)

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(start_epoch +
                                                                 epoch + 1),
                        input_rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(start_epoch +
                                                                epoch + 1),
                        input_gt_depth_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(start_epoch +
                                                                  epoch + 1),
                        pred_depth_image,
                        cmap="viridis")

                loss_local += loss_fn(output, depth_var)

                num_samples += 1

        err = float(loss_local) / num_samples
        print('val_error: %f' % err)

        if err < best_val_err or epoch == num_epochs - 1:
            best_val_err = err
            torch.save(
                {
                    'epoch': start_epoch + epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, './model/model_' + str(start_epoch + epoch + 1) + '.pth')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.8
            # BUGFIX: rebinding the local `learning_rate` never reached the
            # optimizer (constructed once, above), so the decay had no
            # effect; push the decayed rate into its parameter groups.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
def main():
    """Fine-tune FCRN on an augmented NYU dataset from a prior checkpoint.

    Loads model weights from ``our_checkpoint.pth.tar``, runs a baseline
    validation pass, then trains with MSE loss and Adam, validating each
    epoch (saving sample images to the CWD) and checkpointing to
    ``checkpoint.pth.tar`` whenever validation error improves. The learning
    rate is decayed by 0.6 every 10 epochs.
    """
    batch_size = 32
    data_path = 'augmented_dataset.mat'
    learning_rate = 1.0e-5
    monentum = 0.9  # NOTE(review): unused -- Adam is the optimizer below
    weight_decay = 0.0005  # NOTE(review): unused below
    num_epochs = 50
    resume_from_file = False

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, train_lists),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, val_lists),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=True)
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, test_lists),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)
    print(train_loader)

    # 2.Load model
    print("Loading model......")
    model = FCRN(batch_size)
    # NOTE(review): a torchvision ResNet-50 and its pretrained weights were
    # loaded here, but the resulting state dict was never applied to `model`
    # (all the merging code was commented out), so the dead construction and
    # file I/O have been removed. To warm-start the encoder:
    #   resnet = torchvision.models.resnet50()
    #   resnet.load_state_dict(torch.load('/content/model/resnet50-19c8e357.pth'))
    #   ...merge resnet.state_dict() into model.state_dict()...

    # Resume the full model from our previous training run.
    chkp = torch.load('our_checkpoint.pth.tar')
    model.load_state_dict(chkp['state_dict'])
    model = model.cuda()

    # 3.Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    # 5.Train
    best_val_err = 1.0e3

    # Baseline validation pass before any training.
    model.eval()
    num_samples = 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(
                1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(
                np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(
                np.float32)

            # Normalize to [0, 1] for visualization.
            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('input_rgb_epoch_0.png', input_rgb_image)
            plot.imsave('gt_depth_epoch_0.png',
                        input_gt_depth_image,
                        cmap="viridis")
            plot.imsave('pred_depth_epoch_0.png',
                        pred_depth_image,
                        cmap="viridis")

            loss_local += loss_fn(output, depth_var)

            num_samples += 1

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

    start_epoch = 0

    # Optionally resume weights and the epoch counter.
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 4.Optim -- created once. BUGFIX: the optimizer used to be re-built at
    # the top of every epoch, silently resetting Adam's moment estimates.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set.")

    for epoch in range(num_epochs):
        print('Starting train epoch %d / %d' %
              (start_epoch + epoch + 1, num_epochs))
        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0

        for input, depth in train_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.data.cpu())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # Validation pass.
        model.eval()
        num_samples = 0
        loss_local = 0
        with torch.no_grad():
            for input, depth in val_loader:
                input_var = Variable(input.type(dtype))
                depth_var = Variable(depth.type(dtype))

                output = model(input_var)

                input_rgb_image = input_var[0].data.permute(
                    1, 2, 0).cpu().numpy().astype(np.uint8)
                input_gt_depth_image = depth_var[0][0].data.cpu().numpy(
                ).astype(np.float32)
                pred_depth_image = output[0].data.squeeze().cpu().numpy(
                ).astype(np.float32)

                # Normalize to [0, 1] for visualization.
                input_gt_depth_image /= np.max(input_gt_depth_image)
                pred_depth_image /= np.max(pred_depth_image)

                plot.imsave(
                    'input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1),
                    input_rgb_image)
                plot.imsave('gt_depth_epoch_{}.png'.format(start_epoch +
                                                           epoch + 1),
                            input_gt_depth_image,
                            cmap="viridis")
                plot.imsave('pred_depth_epoch_{}.png'.format(start_epoch +
                                                             epoch + 1),
                            pred_depth_image,
                            cmap="viridis")

                loss_local += loss_fn(output, depth_var)

                num_samples += 1

        err = float(loss_local) / num_samples
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': start_epoch + epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        # Decay the learning rate every 10 epochs and push it into the
        # optimizer's parameter groups (the optimizer is no longer rebuilt
        # each epoch, so rebinding the local alone would have no effect).
        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.6
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
Example #7
0
from torch.autograd import Variable
from torchvision import transforms
from fcrn import FCRN
from transformer_net import TransformerNet

image_num = 1  # batch size for the FCRN depth network
dtype = torch.cuda.FloatTensor  # all tensors on GPU

# Checkpoint paths for the two networks.
FCRN_path = 'model/model_300.pth'
resume_FCRN_from_file = True
#TransformerNet_path = 'model/saved_models/rain_princess.pth'
TransformerNet_path = 'model/saved_models/rain_princess.pth'
resume_TransformerNet_from_file = True

# Depth-prediction and style-transfer networks, both moved to the GPU.
model_FCRN = FCRN(image_num)
model_FCRN = model_FCRN.cuda()
model_TransformerNet = TransformerNet()
model_TransformerNet = model_TransformerNet.cuda()

# Candidate content images (COCO / custom test shots); path 2 is used below.
content_image_path_1 = 'data/images/content_image/test1.jpg'  # dark-toned room
content_image_path_2 = 'data/images/content_image/COCO_train2014_000000301334.jpg'  # bright-toned room
content_image_path_3 = 'data/images/content_image/COCO_train2014_000000001355.jpg'  # room with a window
content_image_path_4 = 'data/images/content_image/COCO_train2014_000000579045.jpg'  # train
content_image_path_5 = 'data/images/content_image/COCO_train2014_000000304067.jpg'  # bright street
content_image_path_6 = 'data/images/content_image/test2.jpg'
content_image = Image.open(content_image_path_2)

# input FCRN parameters
if resume_FCRN_from_file:
    if os.path.isfile(FCRN_path):
        print("=> loading checkpoint '{}'".format(FCRN_path))