Exemplo n.º 1
0
def TestNetwork():
    """Drive the robot with the trained depth + Q-network until ROS shuts down.

    Loads the FCRN depth predictor from ``checkpoint.pth.tar`` and the
    Q-network from ``online_with_noise.pth.tar``. Then, at 3 Hz, it grabs an
    RGB frame from ``RealWorld``, predicts a depth map, pushes it onto a
    rolling history of ``IMAGE_HIST`` depth frames (newest first) and sends
    the greedy (arg-max Q) action to ``env.Control``. Each predicted depth map
    is also written, scaled by 50, to ``predicted_depth.png``.
    """
    # Depth-prediction network.
    depth_net = FCRN(batch_size)
    checkpoint = torch.load('checkpoint.pth.tar')
    depth_net.load_state_dict(checkpoint['state_dict'])
    depth_net.cuda()
    # Inference only: without eval() any batch-norm / dropout layers would run
    # in training mode on single frames.
    depth_net.eval()

    # Q-learning network.
    online_net = QNetwork()
    checkpoint_online = torch.load('online_with_noise.pth.tar')
    online_net.load_state_dict(checkpoint_online['state_dict'])
    online_net.cuda()
    online_net.eval()
    rospy.sleep(1.)

    env = RealWorld()
    print('Environment initialized')

    def observe_depth():
        """Return the depth prediction for the current camera frame.

        The observation gets a batch axis and is permuted channel-last ->
        channel-first (assumes an (H, W, C) image), cast to ``dtype`` and run
        through ``depth_net``; 0.2 is subtracted from the network output as in
        the original pipeline.
        """
        rgb = env.GetRGBImageObservation()
        rgb = torch.from_numpy(rgb[np.newaxis, :]).permute(0, 3, 1, 2)
        return depth_net(rgb.type(dtype)) - 0.2

    # Control-loop frequency (Hz).
    rate = rospy.Rate(3)

    with torch.no_grad():
        if rospy.is_shutdown():
            return
        # Seed the history with IMAGE_HIST copies of the first frame so the
        # Q-network always receives exactly IMAGE_HIST channels (this used to
        # be hard-coded to 4 copies).
        first_depth = observe_depth()
        depth_history = torch.cat([first_depth] * IMAGE_HIST, dim=1)

        while not rospy.is_shutdown():
            depth = observe_depth()
            # Newest frame goes first; the oldest frame is dropped.
            depth_history = torch.cat(
                (depth, depth_history[:, :IMAGE_HIST - 1]), dim=1)

            predicted_depth = depth[0].squeeze().cpu().numpy().astype(np.float32)
            cv2.imwrite('predicted_depth.png', predicted_depth * 50)

            q_values = online_net(depth_history.type(dtype))[0]
            _, action = torch.max(q_values, 0)
            env.Control(action)
            rate.sleep()
Exemplo n.º 2
0
# Metric accumulators. The names suggest scale-invariant log RMSE and
# (presumably) absolute / squared relative difference; they are only
# initialised here and presumably updated by evaluation code after this excerpt.
RMSE_log_scale_invariant = 0.0
ARD = 0.0
SRD = 0.0

# FCRN depth model and an MSE loss, both on the GPU. `batch_size`,
# `resume_from_file` and `data_path` are assumed to be defined earlier in the
# script (not visible here).
model = FCRN(batch_size)
model = model.cuda()
loss_fn = torch.nn.MSELoss().cuda()

resume_file = 'checkpoint.pth.tar'

# Optionally restore weights from a checkpoint dict holding 'epoch' and
# 'state_dict'; a missing file only prints a message and keeps the fresh model.
if resume_from_file:
    if os.path.isfile(resume_file):
        print("=> loading checkpoint '{}'".format(resume_file))
        checkpoint = torch.load(resume_file)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_file, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resume_file))

# Only the third (test) split returned by load_split() is used here.
_, _, test_lists = load_split()
num_samples = len(test_lists)

# Deterministic order and no dropped tail batch, so every test sample is seen.
test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
    data_path, test_lists),
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=False)
# Inference mode (affects layers such as dropout / batch norm).
model.eval()
# Running index; presumably consumed by the evaluation loop that follows.
idx = 0
Exemplo n.º 3
0
def main():
    """Train FCRN on NYU Depth with TensorBoard logging and periodic checkpoints.

    Resumes from ``checkpoints/<exp_name>/model.pth`` when ``resume`` is True and
    that file exists. Every ``print_every`` iterations the median losses and one
    sample (image, depth, normals, confidence) are written to TensorBoard; every
    ``val_every`` epochs ``validate`` runs on the train and test loaders; every
    ``save_every`` epochs model and optimizer state are saved.
    """
    resume = True
    path = 'data/NYU_DEPTH'
    batch_size = 16
    epochs = 10000
    device = torch.device('cuda:0')
    print_every = 5
    # exp_name = 'resnet18_nodropout_new'
    exp_name = 'only_depth'
    # exp_name = 'normal_internel'
    # exp_name = 'sep'
    lr = 1e-5
    weight_decay = 0.0005
    log_dir = os.path.join('logs', exp_name)
    model_dir = os.path.join('checkpoints', exp_name)
    model_path = os.path.join(model_dir, 'model.pth')
    val_every = 16
    save_every = 16

    # TensorBoard: start from an empty log directory unless resuming.
    if not resume and os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    tb = SummaryWriter(log_dir)
    tb.add_custom_scalars({
        'metrics': {
            'thres_1.25': ['Multiline', ['thres_1.25/train', 'thres_1.25/test']],
            'thres_1.25_2': ['Multiline', ['thres_1.25_2/train', 'thres_1.25_2/test']],
            'thres_1.25_3': ['Multiline', ['thres_1.25_3/train', 'thres_1.25_3/test']],
            'ard': ['Multiline', ['ard/train', 'ard/test']],
            'srd': ['Multiline', ['srd/train', 'srd/test']],
            'rmse_linear': ['Multiline', ['rmse_linear/train', 'rmse_linear/test']],
            'rmse_log': ['Multiline', ['rmse_log/train', 'rmse_log/test']],
            'rmse_log_invariant': ['Multiline', ['rmse_log_invariant/train', 'rmse_log_invariant/test']],
        }
    })

    # Data loaders.
    dataset = NYUDepth(path, 'train')
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=4)
    dataset_test = NYUDepth(path, 'test')
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True, num_workers=4)

    # Model and optimizer.
    model = FCRN(True)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    start_epoch = 0
    if resume:
        if os.path.exists(model_path):
            print('Loading checkpoint from {}...'.format(model_path))
            checkpoint = torch.load(model_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # 'epoch' is the first epoch that still has to be run (see save below).
            start_epoch = checkpoint['epoch']
            print('Model loaded.')
        else:
            print('No checkpoint found. Train from scratch')

    metric_logger = MetricLogger()
    max_iters = epochs * len(dataloader)

    def normal_loss(pred, normal, conf):
        """Confidence-weighted (1 - cosine) loss between predicted and GT normals.

        :param pred: (B, 3, H, W)
        :param normal: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        :return: scalar, mean over the batch of the per-image weighted loss
        """
        dot_prod = (pred * normal).sum(dim=1)
        # weighted loss, (B, )
        batch_loss = ((1 - dot_prod) * conf[:, 0]).sum(1).sum(1)
        # normalize, to (B, )
        batch_loss /= conf[:, 0].sum(1).sum(1)
        return batch_loss.mean()

    def consistency_loss(pred, cloud, normal, conf):
        """Penalise alignment between local point-cloud differences and normals.

        The depth channel of ``cloud`` is replaced by ``pred``; a dilated 7x7
        kernel (48 at the centre, -1 elsewhere) yields per-pixel difference
        vectors whose absolute dot product with ``normal`` is averaged with
        ``conf`` weights.

        :param pred: (B, 1, H, W)
        :param normal: (B, 3, H, W)
        :param cloud: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        B, _, _, _ = normal.size()
        normal = normal.detach()
        cloud = cloud.clone()
        cloud[:, 2:3, :, :] = pred
        kernel = torch.ones((1, 1, 7, 7), device=pred.device)
        kernel = -kernel
        kernel[0, 0, 3, 3] = 48

        diffs = [F.conv2d(cloud[:, c:c + 1], kernel, padding=6, dilation=2)
                 for c in range(3)]
        # (B, 3, H, W), normalised per pixel
        diff = F.normalize(torch.cat(diffs, dim=1), dim=1)
        # (B, 1, H, W)
        dot_prod = (diff * normal).sum(dim=1, keepdim=True)
        # weighted mean over image
        dot_prod = torch.abs(dot_prod.view(B, -1))
        conf = conf.view(B, -1)
        loss = (dot_prod * conf).sum(1) / conf.sum(1)
        # mean over batch
        return loss.mean()

    def criterion(depth_pred, normal_pred, depth, normal, cloud, conf):
        """Return the (mse, consistency, normal) loss terms for one batch.

        NOTE: in the current 'only_depth' setup all three returned terms are the
        depth MSE, so the total loss is 3 * MSE and the logged consis/norm losses
        mirror the MSE. The consistency and normal losses were previously
        computed and then discarded; they are no longer evaluated. To switch
        setups, use one of the commented returns below.
        """
        mse_loss = F.mse_loss(depth_pred, depth)
        return mse_loss, mse_loss, mse_loss
        # consis_loss = consistency_loss(depth_pred, cloud, normal_pred, conf)
        # norm_loss = normal_loss(normal_pred, normal, conf)
        # return mse_loss, consis_loss, norm_loss
        # return norm_loss, norm_loss, norm_loss

    print('Start training')
    for epoch in range(start_epoch, epochs):
        model.train()
        # Restart the timer so validation / checkpointing from the previous
        # epoch is not counted as the first batch's time (it skewed the ETA).
        end = time.perf_counter()
        for i, data in enumerate(dataloader, start=1):
            start = end
            image, depth, normal, conf, cloud = [x.to(device) for x in data]
            depth_pred, normal_pred = model(image)
            mse_loss, consis_loss, norm_loss = criterion(
                depth_pred, normal_pred, depth, normal, cloud, conf)
            loss = mse_loss + consis_loss + norm_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # bookkeeping
            end = time.perf_counter()
            metric_logger.update(loss=loss.item())
            metric_logger.update(mse_loss=mse_loss.item())
            metric_logger.update(norm_loss=norm_loss.item())
            metric_logger.update(consis_loss=consis_loss.item())
            metric_logger.update(batch_time=end - start)

            if i % print_every == 0:
                # Compute eta. global step: starting from 1
                global_step = epoch * len(dataloader) + i
                seconds = (max_iters - global_step) * metric_logger['batch_time'].global_avg
                eta = datetime.timedelta(seconds=int(seconds))
                display_dict = {
                    'eta': eta,
                    'epoch': epoch,
                    'iter': i,
                    'loss': metric_logger['loss'].median,
                    'batch_time': metric_logger['batch_time'].median
                }
                display_str = [
                    'eta: {eta}s',
                    'epoch: {epoch}',
                    'iter: {iter}',
                    'loss: {loss:.4f}',
                    'batch_time: {batch_time:.4f}s',
                ]
                print(', '.join(display_str).format(**display_dict))

                # Visualise the first sample; depth is scaled to [0, 1] with
                # headroom above the ground-truth maximum (x1.25), and normals
                # are mapped from [-1, 1] to [0, 1].
                min_depth = depth[0].min()
                max_depth = depth[0].max() * 1.25
                depth_vis = (depth[0] - min_depth) / (max_depth - min_depth)
                depth_pred_vis = torch.clamp(
                    (depth_pred[0] - min_depth) / (max_depth - min_depth),
                    min=0.0, max=1.0)
                normal_vis = (normal[0] + 1) / 2
                normal_pred_vis = (normal_pred[0] + 1) / 2

                tb.add_scalar('train/loss', metric_logger['loss'].median, global_step)
                tb.add_scalar('train/mse_loss', metric_logger['mse_loss'].median, global_step)
                tb.add_scalar('train/consis_loss', metric_logger['consis_loss'].median, global_step)
                tb.add_scalar('train/norm_loss', metric_logger['norm_loss'].median, global_step)

                tb.add_image('train/depth', depth_vis, global_step)
                tb.add_image('train/normal', normal_vis, global_step)
                tb.add_image('train/depth_pred', depth_pred_vis, global_step)
                tb.add_image('train/normal_pred', normal_pred_vis, global_step)
                tb.add_image('train/conf', conf[0], global_step)
                tb.add_image('train/image', image[0], global_step)

        if epoch % val_every == 0 and epoch != 0:
            validate(dataloader, model, device, tb, epoch, 'train')
            validate(dataloader_test, model, device, tb, epoch, 'test')
        if epoch % save_every == 0 and epoch != 0:
            to_save = {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                # This epoch has finished: resume from the next one instead of
                # re-training it.
                'epoch': epoch + 1,
            }
            torch.save(to_save, model_path)
Exemplo n.º 4
0
    # NOTE(review): this excerpt appears to be the body of a function whose
    # header is not shown; `batch_size`, `resume_from_file` and `data_path`
    # are assumed to be defined by that enclosing scope.
    # Metric accumulators, presumably updated by evaluation code that follows.
    RMSE_log_scale_invariant = 0.0
    ARD = 0.0
    SRD = 0.0

    # FCRN depth model and an MSE loss, both on the GPU.
    model = FCRN(batch_size)
    model = model.cuda()
    loss_fn = torch.nn.MSELoss().cuda()

    resume_file = 'berhu.pth'

    # 'berhu.pth' is loaded as a raw state dict (no 'epoch'/'state_dict'
    # wrapper), unlike the 'checkpoint.pth.tar' format.
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            # start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint)
            # print("=> loaded checkpoint '{}' (epoch {})"
            #       .format(resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # Only the third (test) split returned by load_split() is used here.
    _, _, test_lists = load_split()
    num_samples = len(test_lists)
    print(num_samples)
    # Deterministic order and no dropped tail batch, so every test sample is seen.
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, test_lists),
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=False)
    # Inference mode (affects layers such as dropout / batch norm).
    model.eval()
    # Running index; presumably consumed by the evaluation loop that follows.
    idx = 0
def main():
    """Train FCRN on NYU Depth v2 with a configurable loss and keep the best model.

    Hyper-parameters come from the global ``args`` (batch size, learning rate,
    epochs, StepLR schedule, loss type, data-augmentation flag). The model is
    validated once before training and after every epoch; sample RGB / ground
    truth / predicted depth images of the first validation batch are written
    under ``./result/``. The lowest-validation-error model is saved to
    ``checkpoint.pth.tar`` and, every 10th epoch, the weights are saved to
    ``<loss_type>.pth``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the supported losses.
    """
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr  # 1.0e-4 and 1.0e-5 were used previously
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False
    isDataAug = args.data_aug
    max_depth = 1000

    def prepare_depth(depth):
        """Apply the optional target transform used when ``isDataAug`` is set.

        Depth is multiplied by 1000, clamped to [10, 1000] and replaced by
        ``max_depth / depth`` (values in [1, 100]); otherwise it is unchanged.
        Training and validation must use the same transform.
        """
        if isDataAug:
            depth = torch.clamp(depth * 1000, 10, 1000)
            depth = max_depth / depth
        return depth

    def save_samples(rgb, depth, output, tag):
        """Write one RGB / GT depth / predicted depth triple to ./result/."""
        rgb_image = rgb[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        gt_depth_image = depth[0][0].cpu().numpy().astype(np.float32)
        pred_depth_image = output[0].squeeze().cpu().numpy().astype(np.float32)
        # Normalise by the per-image maximum for visualisation.
        gt_depth_image /= np.max(gt_depth_image)
        pred_depth_image /= np.max(pred_depth_image)
        plot.imsave('./result/input_rgb_epoch_{}.png'.format(tag), rgb_image)
        plot.imsave('./result/gt_depth_epoch_{}.png'.format(tag),
                    gt_depth_image, cmap="viridis")
        plot.imsave('./result/pred_depth_epoch_{}.png'.format(tag),
                    pred_depth_image, cmap="viridis")

    def run_validation(tag):
        """Return the mean validation loss per batch; save samples for ``tag``.

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        model.eval()
        loss_total = 0
        num_batches = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                rgb_var = rgb.type(dtype)
                depth_var = prepare_depth(depth).type(dtype)
                output = model(rgb_var)
                if num_batches == 0:
                    save_samples(rgb_var, depth_var, output, tag)
                loss_total += loss_fn(output, depth_var)
                num_batches += 1
        if num_batches == 0:
            raise RuntimeError('validation loader yielded no batches')
        return float(loss_total) / num_batches

    os.makedirs('./result', exist_ok=True)

    # 1. Load data
    train_lists, val_lists, _ = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size, shuffle=True, drop_last=True)
    print(train_loader)

    # 2. Model; a partially trained model can be restored via resume_from_file.
    print("Set the model......")
    model = FCRN(batch_size)
    model = model.cuda()

    start_epoch = 0
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 3. Loss
    loss_choices = {
        "berhu": (criteria.berHuLoss, "berhu loss_fn set."),
        "L1": (criteria.MaskedL1Loss, "L1 loss_fn set."),
        "mse": (criteria.MaskedMSELoss, "MSE loss_fn set."),
        "ssim": (criteria.SsimLoss, "Ssim loss_fn set."),
        "three": (criteria.Ssim_grad_L1, "SSIM+L1+Grad loss_fn set."),
    }
    if args.loss_type not in loss_choices:
        raise ValueError("unsupported loss_type: {!r}".format(args.loss_type))
    make_loss, loss_message = loss_choices[args.loss_type]
    loss_fn = make_loss().cuda()
    print(loss_message)

    # Baseline error of the (possibly resumed) model, with the same target
    # transform as the per-epoch validation so the numbers are comparable.
    best_val_err = 1.0e3
    err = run_validation(0)
    print('val_error before train:', err)

    # 4. Optim (SGD with momentum 0.9 / weight decay 0.0005 was tried before)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=step_size,
                       gamma=step_gamma)  # may change to other value
    print("optimizer set.")

    # 5. Train
    for epoch in range(num_epochs):
        current_epoch = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' %
              (current_epoch, start_epoch + num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = prepare_depth(depth).type(dtype)

            output = model(rgb_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.item())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        err = run_validation(current_epoch)
        print('val_error:', err)

        # Periodic snapshot (previously re-saved once per validation batch).
        if epoch % 10 == 9:
            torch.save(model.state_dict(), args.loss_type + '.pth')

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': current_epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        scheduler.step()
Exemplo n.º 6
0
import torch
import torch.nn as nn
import random
import numpy as np
import time
import rospy
from fcrn import FCRN  # this file could be changed
import cv2
from torch.autograd import Variable

batch_size = 1
dtype = torch.cuda.FloatTensor

# Depth-prediction network restored from the FCRN training checkpoint.
depth_net = FCRN(batch_size)
resume_file_depth = 'checkpoint.pth.tar'
checkpoint = torch.load(resume_file_depth)
depth_net.load_state_dict(checkpoint['state_dict'])
depth_net.cuda()
# Inference only: keep batch-norm / dropout layers in evaluation mode.
depth_net.eval()

# img = cv2.imread("299902454.png") * 2
image_path = "./test_images/599.png"
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals a missing or unreadable file by returning None.
    raise IOError("could not read image: {}".format(image_path))
# print() call form works on both Python 2 and 3 (the old statement form is
# a SyntaxError on Python 3).
print(img.shape)
img = cv2.resize(img, (640, 480), interpolation=cv2.INTER_NEAREST)
# OpenCV loads BGR; convert to RGB before feeding the network.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = cv2.imread("rgb_depth.png")
# img[:, :, 2] = img[:, :, 2] * 0.2
# img[:, :, 1] = img[:, :, 1] * 0.85
# img[:, :, 0] = img[:, :, 0] * 0.88
# print(img.shape)
# cv2.imwrite("rgb_depth_pre.png", img)
# Add a batch axis and move channels first: (1, H, W, C) -> (1, C, H, W).
img = torch.from_numpy(img[np.newaxis, :])
img = img.permute(0, 3, 1, 2)
Exemplo n.º 7
0
def main():
    """Fine-tune FCRN on NYU Depth v2 with the reverse Huber (berHu) loss.

    Starts from the official weights converted from TensorFlow
    (``load_weights``), optionally resumes from ``resume_file``, trains for
    ``num_epochs`` epochs with Adam and, after every epoch, computes the mean
    validation loss per batch. A checkpoint ``./model/model_<epoch>.pth`` is
    written whenever validation improves and after the final epoch; on the
    final epoch sample images are written under ``./result/``. The learning
    rate is multiplied by 0.8 after epochs 0, 10, 20, ...
    """
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    num_epochs = 100

    os.makedirs('./model', exist_ok=True)
    os.makedirs('./result', exist_ok=True)

    # 1. Load data
    train_lists, val_lists, _ = load_split()
    print("Loading data...")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size, shuffle=True, drop_last=True)
    print(train_loader)

    # 2. Load model: official parameters converted from TensorFlow.
    print("Loading model...")
    model = FCRN(batch_size)
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # Optionally resume from a previously trained checkpoint. start_epoch must
    # be initialised before this, otherwise the resumed epoch is discarded.
    start_epoch = 0
    resume_from_file = False
    resume_file = './model/model_300.pth'
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("can not find!")
    model = model.cuda()

    # 3. Loss: the paper's reverse Huber (berHu) loss. Alternatives are the
    # built-in torch.nn.MSELoss() or the custom loss_mse().
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4. Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set...")

    # 5. Train. Start with an infinite "best" error so the first epoch's
    # result is always kept (1e-4 practically never triggered a save).
    best_val_err = float('inf')

    for epoch in range(num_epochs):
        current_epoch = start_epoch + epoch + 1
        is_last_epoch = epoch == num_epochs - 1
        print('Starting train epoch %d / %d' %
              (current_epoch, num_epochs + start_epoch))
        model.train()
        running_loss = 0.0
        count = 0
        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(rgb_var)
            loss = loss_fn(output, depth_var)
            print('loss: %f' % loss.item())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_batches = 0
        loss_local = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                rgb_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(rgb_var)
                if is_last_epoch and num_batches == 0:
                    # matplotlib rejects float RGB outside [0, 1], so convert to
                    # a uint8 numpy image as the sibling training script does.
                    # NOTE(review): assumes the loader yields 0-255 RGB values.
                    input_rgb_image = rgb[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                    input_gt_depth_image = depth_var[0][0].cpu().numpy().astype(np.float32)
                    pred_depth_image = output[0].squeeze().cpu().numpy().astype(np.float32)

                    input_gt_depth_image /= np.max(input_gt_depth_image)
                    pred_depth_image /= np.max(pred_depth_image)

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(current_epoch),
                        input_rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(current_epoch),
                        input_gt_depth_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(current_epoch),
                        pred_depth_image,
                        cmap="viridis")

                loss_local += loss_fn(output, depth_var)
                num_batches += 1

        err = float(loss_local) / num_batches
        print('val_error: %f' % err)

        if err < best_val_err or is_last_epoch:
            best_val_err = err
            torch.save(
                {
                    'epoch': current_epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, './model/model_' + str(current_epoch) + '.pth')

        # Learning-rate decay: must be written into the optimizer's param
        # groups; changing the local variable alone has no effect.
        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.8
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
def _scale_to_unit_max(image):
    """Return *image* divided by its maximum value, or *image* unchanged.

    The division is skipped when the maximum is not positive. This avoids
    producing an all-NaN image (0 / 0) or a sign-flipped image when every
    value is negative.
    """
    peak = np.max(image)
    if peak > 0:
        return image / peak
    return image


def _save_depth_preview(rgb, gt_depth, pred_depth, tag):
    """Save the first sample of a batch as three PNG files.

    Writes ``input_rgb_epoch_{tag}.png``, ``gt_depth_epoch_{tag}.png`` and
    ``pred_depth_epoch_{tag}.png`` to the current working directory.
    Channel-first RGB is permuted to HWC for saving. The ground truth is
    taken from channel 0 of the first sample, and the prediction is squeezed.
    """
    rgb_image = rgb[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    gt_image = gt_depth[0][0].cpu().numpy().astype(np.float32)
    pred_image = pred_depth[0].squeeze().cpu().numpy().astype(np.float32)

    plot.imsave('input_rgb_epoch_{}.png'.format(tag), rgb_image)
    plot.imsave('gt_depth_epoch_{}.png'.format(tag),
                _scale_to_unit_max(gt_image),
                cmap="viridis")
    plot.imsave('pred_depth_epoch_{}.png'.format(tag),
                _scale_to_unit_max(pred_image),
                cmap="viridis")


def _validate(model, loader, loss_fn, tag):
    """Return the mean per-batch loss of *model* over *loader*.

    The model is put in eval mode and run without gradient tracking. A preview
    of the first batch only is written with ``_save_depth_preview``; saving
    every batch would just overwrite the same files. If *loader* yields no
    batches, a warning is printed and NaN is returned instead of dividing
    by zero.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for rgb, depth in loader:
            rgb_var = rgb.type(dtype)
            depth_var = depth.type(dtype)
            output = model(rgb_var)
            if num_batches == 0:
                _save_depth_preview(rgb_var, depth_var, output, tag)
            total_loss += float(loss_fn(output, depth_var))
            num_batches += 1
    if num_batches == 0:
        print('warning: validation loader produced no batches')
        return float('nan')
    return total_loss / num_batches


def main():
    """Fine-tune FCRN on the augmented NYU depth dataset.

    Steps:
      1. Build shuffled train and validation loaders from
         ``augmented_dataset.mat``.
      2. Initialise FCRN from ``our_checkpoint.pth.tar``. When
         ``resume_from_file`` is set and ``checkpoint.pth.tar`` exists, those
         weights and their epoch counter are used instead.
      3. Report a baseline validation error.
      4. Train with Adam on an MSE loss for ``num_epochs`` epochs.

    After every epoch the model is validated. It is saved to
    ``checkpoint.pth.tar`` whenever the validation error improves.
    """
    batch_size = 32
    data_path = 'augmented_dataset.mat'
    learning_rate = 1.0e-5
    num_epochs = 50
    resume_from_file = False

    # 1. Load data. The test split is not used by this script.
    train_lists, val_lists, _ = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)

    # 2. Load model: start from the project's own fine-tuned checkpoint.
    print("Loading model......")
    model = FCRN(batch_size)
    chkp = torch.load('our_checkpoint.pth.tar')
    model.load_state_dict(chkp['state_dict'])
    model = model.cuda()

    # Resume *before* the baseline validation, so the baseline reflects the
    # weights that training actually starts from.
    start_epoch = 0
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 3. Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    err = _validate(model, val_loader, loss_fn, 0)
    print('val_error before train:', err)

    # 4. Optimizer: created once, so Adam's running moment estimates persist
    # across epochs. The learning rate is decayed in place further down.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set.")

    # 5. Train
    best_val_err = 1.0e3
    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' %
              (epoch_number, start_epoch + num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(rgb_var)
            loss = loss_fn(output, depth_var)
            batch_loss = loss.item()
            print('loss:', batch_loss)
            count += 1
            running_loss += batch_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Guard against an empty training loader instead of dividing by zero.
        epoch_loss = running_loss / count if count else float('nan')
        print('epoch loss:', epoch_loss)

        err = _validate(model, val_loader, loss_fn, epoch_number)
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        # Same schedule as before: multiply by 0.6 after epochs 0, 10, 20, ...
        # Apply it to the live optimizer rather than rebuilding it.
        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.6
            for group in optimizer.param_groups:
                group['lr'] = learning_rate
Exemplo n.º 9
0
content_image_path_1 = 'data/images/content_image/test1.jpg'  # dark-toned room
content_image_path_2 = 'data/images/content_image/COCO_train2014_000000301334.jpg'  # bright-toned room
content_image_path_3 = 'data/images/content_image/COCO_train2014_000000001355.jpg'  # room with a window
content_image_path_4 = 'data/images/content_image/COCO_train2014_000000579045.jpg'  # train
content_image_path_5 = 'data/images/content_image/COCO_train2014_000000304067.jpg'  # brightly lit street
content_image_path_6 = 'data/images/content_image/test2.jpg'
# Of the candidates above, only path 2 is opened here.
content_image = Image.open(content_image_path_2)

# Optionally restore the FCRN weights, and the epoch counter stored with them,
# from FCRN_path. A missing file is reported but not treated as an error.
if resume_FCRN_from_file and not os.path.isfile(FCRN_path):
    print("=> no checkpoint found at '{}'".format(FCRN_path))
elif resume_FCRN_from_file:
    print("=> loading checkpoint '{}'".format(FCRN_path))
    FCRN_par = torch.load(FCRN_path)
    start_epoch = FCRN_par['epoch']
    model_FCRN.load_state_dict(FCRN_par['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        FCRN_path, FCRN_par['epoch']))

# input TransformerNet parameters
if resume_TransformerNet_from_file:
    if os.path.isfile(TransformerNet_path):
        print("=> loading checkpoint '{}'".format(TransformerNet_path))
        TransformerNet_par = torch.load(TransformerNet_path)
        for k in list(TransformerNet_par.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del TransformerNet_par[k]
        model_TransformerNet.load_state_dict(TransformerNet_par)
        print("=> loaded checkpoint '{}'".format(TransformerNet_path))