Exemplo n.º 1
0
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_file, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resume_file))

# Evaluation only needs the third value returned by load_split(); the first
# two are discarded.
_, _, test_lists = load_split()
num_samples = len(test_lists)

# Keep the split order and every sample: no shuffling, no dropped tail batch.
test_loader = torch.utils.data.DataLoader(
    NyuDepthLoader(data_path, test_lists),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

# Running state for the evaluation loop that follows.
idx = 0
mean_difference = 0.

# Switch the model to evaluation mode before running predictions.
model.eval()
with torch.no_grad():
    for input, gt in test_loader:
        input_var = Variable(input.type(dtype))
        gt_var = Variable(gt.type(dtype))

        #input_var = input_var.unsqueeze(0)

        output = model(input_var)

        print('predict complete.')

        #input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        input_gt_depth_image = gt_var[0].data.squeeze().cpu().numpy().astype(
def main():
    """Train FCRN with hyper-parameters taken from the module-level ``args``.

    Steps:
      1. Build train/val/test loaders from ``load_split()``.
      2. Create the model and move it to the GPU.
      3. Select the loss named by ``args.loss_type``.
      4. Report a baseline validation error before any training.
      5. Run ``args.epochs`` epochs of Adam + StepLR training. After each
         epoch the validation error is computed, preview images are written
         to ``./result/``, and ``checkpoint.pth.tar`` is rewritten whenever
         the validation error improves. Every tenth epoch the weights are
         also saved to ``<loss_type>.pth``.

    Raises:
        ValueError: if ``args.loss_type`` is not a supported loss name.
        RuntimeError: if the train or validation loader yields no batches
            (both loaders use ``drop_last=True``).
    """
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False
    resume_file = 'checkpoint.pth.tar'
    isDataAug = args.data_aug
    max_depth = 1000

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    # NOTE(review): test_loader is never used in this function; it is kept
    # only because NyuDepthLoader's construction side effects are not
    # visible from here. Remove once confirmed unnecessary.
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=False,
        drop_last=True)
    print(train_loader)

    # 2. Set the model
    print("Set the model......")
    model = FCRN(batch_size)
    model = model.cuda()

    # 3. Loss
    if args.loss_type == "berhu":
        loss_fn = criteria.berHuLoss().cuda()
        print("berhu loss_fn set.")
    elif args.loss_type == "L1":
        loss_fn = criteria.MaskedL1Loss().cuda()
        print("L1 loss_fn set.")
    elif args.loss_type == "mse":
        loss_fn = criteria.MaskedMSELoss().cuda()
        print("MSE loss_fn set.")
    elif args.loss_type == "ssim":
        loss_fn = criteria.SsimLoss().cuda()
        print("Ssim loss_fn set.")
    elif args.loss_type == "three":
        loss_fn = criteria.Ssim_grad_L1().cuda()
        print("SSIM+L1+Grad loss_fn set.")
    else:
        # Without this branch loss_fn would be unbound and the first use
        # would fail with an unrelated-looking UnboundLocalError.
        raise ValueError("unsupported loss_type: {!r}".format(args.loss_type))

    def prepare_depth(depth):
        """Return the depth target, transformed when ``args.data_aug`` is set."""
        if not isDataAug:
            return depth
        depth = torch.clamp(depth * 1000, 10, 1000)
        return max_depth / depth

    def validate(epoch_tag):
        """Return the mean validation loss and save preview images.

        ``epoch_tag`` is the number embedded in the image file names. The
        previews come from the first sample of the first batch; they are
        written once per pass instead of being overwritten on every batch.
        """
        model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                rgb_var = rgb.type(dtype)
                # Same target transform as in training, so the reported
                # error is on the same scale as the training loss.
                depth_var = prepare_depth(depth).type(dtype)

                output = model(rgb_var)

                if num_batches == 0:
                    rgb_image = rgb_var[0].permute(
                        1, 2, 0).cpu().numpy().astype(np.uint8)
                    gt_image = depth_var[0][0].cpu().numpy().astype(
                        np.float32)
                    pred_image = output[0].squeeze().cpu().numpy().astype(
                        np.float32)

                    # Scale each map by its own maximum for display.
                    gt_image /= np.max(gt_image)
                    pred_image /= np.max(pred_image)

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(epoch_tag),
                        rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(epoch_tag),
                        gt_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(epoch_tag),
                        pred_image,
                        cmap="viridis")

                total_loss += float(loss_fn(output, depth_var))
                num_batches += 1

        if num_batches == 0:
            raise RuntimeError(
                "validation loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        return total_loss / num_batches

    # 4. Baseline validation before any training
    best_val_err = 1.0e3
    err = validate(0)
    print('val_error before train:', err)

    # 5. Optimizer, optional resume, scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set.")

    start_epoch = 0
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # The checkpoint written below stores the optimizer state;
            # restore it so Adam's moment estimates survive a resume.
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    scheduler = StepLR(optimizer, step_size=step_size, gamma=step_gamma)

    # 6. Train
    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' %
              (epoch_number, start_epoch + num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = prepare_depth(depth).type(dtype)

            output = model(rgb_var)

            loss = loss_fn(output, depth_var)
            print('loss:', loss.detach().cpu())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if count == 0:
            raise RuntimeError(
                "training loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        err = validate(epoch_number)
        print('val_error:', err)

        # Periodic weights-only snapshot, written once per qualifying epoch.
        if epoch % 10 == 9:
            torch.save(model.state_dict(), args.loss_type + '.pth')

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, resume_file)

        scheduler.step()
def main():
    """Fine-tune FCRN on ``augmented_dataset.mat`` with an MSE loss.

    The model is initialised from ``our_checkpoint.pth.tar``. A baseline
    validation error is reported, then the model is trained for 50 epochs
    with Adam. After every epoch the validation error is computed, preview
    images are written to the working directory, and ``checkpoint.pth.tar``
    is rewritten whenever the validation error improves.

    The learning rate is multiplied by 0.6 after epochs 0, 10, 20, ...
    (zero-based). The decay is applied to the existing optimizer instead of
    re-creating Adam every epoch, so its moment estimates are kept.

    Raises:
        RuntimeError: if the train or validation loader yields no batches
            (both loaders use ``drop_last=True``).
    """
    batch_size = 32
    data_path = 'augmented_dataset.mat'
    learning_rate = 1.0e-5
    lr_decay = 0.6
    num_epochs = 50
    resume_from_file = False
    resume_file = 'checkpoint.pth.tar'

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    # NOTE(review): test_loader is never used in this function; it is kept
    # only because NyuDepthLoader's construction side effects are not
    # visible from here. Remove once confirmed unnecessary.
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    print(train_loader)

    # 2. Load model
    # The weights come entirely from our_checkpoint.pth.tar. The previous
    # ResNet-50 load was dead code: its state dict was never copied into the
    # model, yet it made the run depend on a hard-coded absolute path.
    print("Loading model......")
    model = FCRN(batch_size)
    chkp = torch.load('our_checkpoint.pth.tar')
    model.load_state_dict(chkp['state_dict'])
    print("model_dict loaded.")
    model = model.cuda()

    # 3. Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    def validate(epoch_tag):
        """Return the mean validation loss and save preview images.

        ``epoch_tag`` is the number embedded in the image file names. The
        previews come from the first sample of the first batch; they are
        written once per pass instead of being overwritten on every batch.
        """
        model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                rgb_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(rgb_var)

                if num_batches == 0:
                    rgb_image = rgb_var[0].permute(
                        1, 2, 0).cpu().numpy().astype(np.uint8)
                    gt_image = depth_var[0][0].cpu().numpy().astype(
                        np.float32)
                    pred_image = output[0].squeeze().cpu().numpy().astype(
                        np.float32)

                    # Scale each map by its own maximum for display.
                    gt_image /= np.max(gt_image)
                    pred_image /= np.max(pred_image)

                    plot.imsave('input_rgb_epoch_{}.png'.format(epoch_tag),
                                rgb_image)
                    plot.imsave('gt_depth_epoch_{}.png'.format(epoch_tag),
                                gt_image,
                                cmap="viridis")
                    plot.imsave('pred_depth_epoch_{}.png'.format(epoch_tag),
                                pred_image,
                                cmap="viridis")

                total_loss += float(loss_fn(output, depth_var))
                num_batches += 1

        if num_batches == 0:
            raise RuntimeError(
                "validation loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        return total_loss / num_batches

    # 4. Baseline validation before any training
    best_val_err = 1.0e3
    err = validate(0)
    print('val_error before train:', err)

    # 5. Optimizer, created once so Adam's running moments persist
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set.")

    start_epoch = 0
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # The checkpoint written below stores the optimizer state;
            # restore it so a resumed run continues where it stopped.
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 6. Train
    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' %
              (epoch_number, start_epoch + num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(rgb_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.detach().cpu())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if count == 0:
            raise RuntimeError(
                "training loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        err = validate(epoch_number)
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, resume_file)

        # Same schedule as before (decay takes effect from the next epoch),
        # applied in place on the live optimizer.
        if epoch % 10 == 0:
            for group in optimizer.param_groups:
                group['lr'] *= lr_decay
Exemplo n.º 4
0
def main():
    """Train FCRN on ``./data/nyu_depth_v2_labeled.mat`` with ``loss_huber``.

    The model starts from the weights returned by ``load_weights`` and can
    optionally resume from ``./model/model_300.pth``. It is trained for 100
    epochs with Adam. After every epoch the validation error is computed; a
    checkpoint is written to ``./model/model_<epoch>.pth`` when that error
    beats ``best_val_err`` and always on the final epoch. Preview images are
    written to ``./result/`` during the final epoch only.

    Two fixes over the previous version:
      * ``start_epoch`` read from a resumed checkpoint is no longer reset
        to 0 before training starts.
      * The x0.8 learning-rate decay is applied to the optimizer; before, it
        only changed a local variable the optimizer never read.

    Raises:
        RuntimeError: if the train or validation loader yields no batches
            (both loaders use ``drop_last=True``).
    """
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    lr_decay = 0.8
    num_epochs = 100

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data...")
    # NOTE(review): the training loader is not shuffled (shuffle=False),
    # unlike the validation loader — confirm this is intentional.
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=False,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    # NOTE(review): test_loader is never used in this function; it is kept
    # only because NyuDepthLoader's construction side effects are not
    # visible from here. Remove once confirmed unnecessary.
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    print(train_loader)

    # 2. Load model
    print("Loading model...")
    model = FCRN(batch_size)
    # Load the official parameters (converted from TensorFlow).
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # Optionally resume from a previously trained model.
    resume_from_file = False
    resume_file = './model/model_300.pth'
    start_epoch = 0  # initialised before resuming so a resume can override it
    resumed_optimizer_state = None
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # Applied further down, once the optimizer exists.
            resumed_optimizer_state = checkpoint.get('optimizer')
            print("loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("can not find!")
    model = model.cuda()

    # 3. Loss
    # The loss from the paper: the reverse Huber. Alternatives previously
    # tried here were torch.nn.MSELoss() and a custom loss_mse().
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4. Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if resumed_optimizer_state is not None:
        optimizer.load_state_dict(resumed_optimizer_state)
    print("optimizer set...")

    # 5. Train
    # NOTE(review): 1.0e-4 as the initial "best" error means the
    # improvement branch below rarely fires, so usually only the final
    # epoch is saved. It may be a copy of the learning rate — confirm
    # before raising it, since every hit writes a new model_<epoch>.pth.
    best_val_err = 1.0e-4

    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        is_last_epoch = epoch == num_epochs - 1
        print('Starting train epoch %d / %d' %
              (epoch_number, num_epochs + start_epoch))
        model.train()
        running_loss = 0.0
        count = 0
        for rgb, depth in train_loader:
            rgb_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(rgb_var)
            loss = loss_fn(output, depth_var)
            loss_value = loss.item()
            print('loss: %f' % loss_value)
            count += 1
            running_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if count == 0:
            raise RuntimeError(
                "training loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_samples = 0
        loss_local = 0.0
        with torch.no_grad():
            for rgb, depth in val_loader:
                rgb_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(rgb_var)

                # Previews are written once, from the first batch of the
                # final epoch, instead of being overwritten on every batch.
                if is_last_epoch and num_samples == 0:
                    # For how the saved test images are produced, refer to
                    # the loader implementation.
                    # NOTE(review): the RGB preview is taken from the raw
                    # loader tensor without a uint8 cast; its dtype and value
                    # range are not visible here — confirm imsave accepts it.
                    input_rgb_image = rgb[0].permute(1, 2, 0).cpu().numpy()
                    input_gt_depth_image = depth_var[0][0].cpu().numpy(
                    ).astype(np.float32)
                    pred_depth_image = output[0].squeeze().cpu().numpy(
                    ).astype(np.float32)

                    # Scale each map by its own maximum for display.
                    input_gt_depth_image /= np.max(input_gt_depth_image)
                    pred_depth_image /= np.max(pred_depth_image)

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(epoch_number),
                        input_rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(epoch_number),
                        input_gt_depth_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(
                            epoch_number),
                        pred_depth_image,
                        cmap="viridis")

                loss_local += float(loss_fn(output, depth_var))
                num_samples += 1

        if num_samples == 0:
            raise RuntimeError(
                "validation loader produced no batches "
                "(is the split smaller than batch_size with drop_last=True?)")
        err = loss_local / num_samples
        print('val_error: %f' % err)

        if err < best_val_err or is_last_epoch:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, './model/model_' + str(epoch_number) + '.pth')

        # Decay the rate on the optimizer itself; rebinding a local
        # learning_rate, as before, had no effect on the running optimizer.
        if epoch % 10 == 0:
            for group in optimizer.param_groups:
                group['lr'] *= lr_decay
Exemplo n.º 5
0
        print("=> no checkpoint found at '{}'".format(FCRN_path))

# Optionally restore the TransformerNet weights from a saved checkpoint.
if resume_TransformerNet_from_file:
    if not os.path.isfile(TransformerNet_path):
        print("=> no checkpoint found at '{}'".format(TransformerNet_path))
    else:
        print("=> loading checkpoint '{}'".format(TransformerNet_path))
        TransformerNet_par = torch.load(TransformerNet_path)
        # Remove entries whose key ends in "in<N>.running_mean" or
        # "in<N>.running_var" before loading. Iterate over a snapshot of the
        # keys because entries are deleted during the walk.
        for k in list(TransformerNet_par):
            if re.search(r'in\d+\.running_(mean|var)$', k) is not None:
                del TransformerNet_par[k]
        model_TransformerNet.load_state_dict(TransformerNet_par)
        print("=> loaded checkpoint '{}'".format(TransformerNet_path))

# Both networks are used for inference only from here on.
model_FCRN.eval()
model_TransformerNet.eval()
with torch.no_grad():
    device = torch.device("cuda")

    input = content_image.resize((304, 228), Image.ANTIALIAS)
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Lambda(lambda x: x.mul(255))
    ])
    input_for_style = input_transform(input)
    input_for_depth = transforms.ToTensor()(input)
    input_for_style = input_for_style.unsqueeze(0).to(device)
    input_for_depth = input_for_depth.unsqueeze(0).to(device)

    output_depth = model_FCRN(input_for_depth)