# Example #1
from tqdm import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
from numpy.linalg import inv

from lib.options import BaseOptions
from lib.mesh_util import save_obj_mesh_with_color, reconstruction
from lib.data import EvalWPoseDataset, EvalDataset
from lib.model import HGPIFuNetwNML, HGPIFuMRNet
from lib.geometry import index

from PIL import Image

parser = BaseOptions()


def gen_mesh(res,
             net,
             cuda,
             data,
             save_path,
             thresh=0.5,
             use_octree=True,
             components=False):
    image_tensor_global = data['img_512'].to(device=cuda)
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    net.filter_global(image_tensor_global)
# Example #2
import torch
from torch.utils.data import DataLoader

from lib.options import BaseOptions
from lib.mesh_util import *
from lib.sample_util import *
from lib.train_util import *
from lib.model import *

from PIL import Image
import torchvision.transforms as transforms
import glob
import tqdm

# get options
opt = BaseOptions().parse()


class Evaluator:
    def __init__(self, opt, projection_mode='orthogonal'):
        self.opt = opt
        self.load_size = self.opt.loadSize
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # set cuda
        cuda = torch.device('cuda:%d' % opt.gpu_id)

        # create net
# Example #3
        # visual sanity check
        if visual_demo_flag:

            print("saving demo for visualCheck...")

            # the input 512x512x3 image
            os.system("cp %s ./sample_images/" % (save_path_png))

            # the estimated mesh .obj
            os.system("cp %s ./sample_images/" % (save_path_obj))

            # the gt high-resolution mesh .obj
            gtMesh = read_and_canonize_gt_mesh(preFix=frameIdx[1],
                                               args=args,
                                               withTexture=True)
            ObjIO.save_obj_data_color(
                gtMesh, "./sample_images/%06d_meshGT.obj" % (frameIdx[1]))

            # the gt low-resolution mesh. obj
            os.system("mv %s ./sample_images/" % (save_path_gt_obj))


# Script entry point: parse command-line options and hand off to main().
if __name__ == '__main__':

    # parse args.
    args = BaseOptions().parse()

    # main function (defined elsewhere in this file/module)
    main(args=args)
# Example #4
def train(opt):
    """Train an HGPIFuNet occupancy network.

    Builds train/test datasets and the data loader, constructs the network
    (optionally wrapped in DataParallel), optionally resumes from a
    checkpoint and/or loads and freezes a pre-trained decoder, then runs the
    training loop with AMP gradient scaling. Periodically logs losses, saves
    checkpoints and debug point clouds, and at the end of each epoch runs a
    numerical evaluation and generates sample meshes.

    Args:
        opt: parsed ``BaseOptions`` namespace holding all training settings
            (paths, hyper-parameters, logging frequencies, flags).
    """
    # set cuda: the first id in opt.gpu_ids is the primary device
    device_ids = [int(i) for i in opt.gpu_ids.split(",")]
    cuda = torch.device('cuda:%d' % device_ids[0])

    tb.initWriter(opt)

    train_dataset = TrainDataset(opt, phase='train')
    test_dataset = TrainDataset(opt, phase='test')

    print("Loading training data...")
    coll = MultiViewCollator(opt)
    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=not opt.serial_batches,
                                   collate_fn=coll,
                                   num_workers=opt.num_threads,
                                   pin_memory=opt.pin_memory)

    print('train data size: ', len(train_data_loader))
    # NOTE: batch size should be 1 and use all the points for evaluation;
    # the test set is indexed directly instead of via a DataLoader.
    print('test data size: ', len(test_dataset))

    # create net
    print("Num GPUs: " + str(torch.cuda.device_count()))

    if len(device_ids) > 1:
        netG = MyDataParallel(HGPIFuNet(opt, train_dataset.projection_mode),
                              device_ids=device_ids).to(device=cuda)
    else:
        netG = HGPIFuNet(opt, train_dataset.projection_mode).to(device=cuda)

    optimizerG = torch.optim.RMSprop(netG.parameters(),
                                     lr=opt.learning_rate,
                                     momentum=0,
                                     weight_decay=0)
    lr = opt.learning_rate
    print('Using Network: ', netG.name)

    def set_train():
        netG.train()

    def set_eval():
        netG.eval()

    # load checkpoints
    if opt.load_netG_checkpoint_path is not None:
        print('loading for net G ...', opt.load_netG_checkpoint_path)
        netG.load_state_dict(
            torch.load(opt.load_netG_checkpoint_path, map_location=cuda))

    if opt.continue_train:
        # resume_epoch < 0 means "resume from the latest checkpoint"
        if opt.resume_epoch < 0:
            model_path = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
        else:
            model_path = '%s/%s/netG_epoch_%d' % (opt.checkpoints_path,
                                                  opt.name, opt.resume_epoch)
        print('Resuming from ', model_path)
        netG.load_state_dict(torch.load(model_path, map_location=cuda))

    # Optionally transplant a pre-trained surface decoder and freeze it.
    if opt.decoder_base != "":
        model_opt = BaseOptions().loadOptFromFile(name=opt.decoder_base)
        model_with_decoder = HGPIFuNet(
            model_opt, train_dataset.projection_mode).to(device=cuda)
        model_path = '%s/%s/netG_latest' % (opt.checkpoints_path,
                                            opt.decoder_base)
        model_with_decoder.load_state_dict(
            torch.load(model_path, map_location=cuda))
        netG.surface_classifier = model_with_decoder.surface_classifier

        for param in netG.surface_classifier.parameters():
            param.requires_grad = False

        print("Loaded decoder from {0} and froze parameters...".format(
            opt.decoder_base))

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s/training' % (opt.results_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s/test' % (opt.results_path, opt.name), exist_ok=True)

    BaseOptions().saveOptToFile(opt)

    scaler = torch.cuda.amp.GradScaler()
    # training
    start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch, 0)
    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()

        set_train()
        iter_data_time = time.time()

        for train_idx, train_data in enumerate(train_data_loader):
            iter_start_time = time.time()

            train_data = move_to_gpu(train_data, cuda)
            pred, nmls, _, error = netG.forward(
                train_data['images'],
                train_data['samples'],
                train_data['calib'],
                imgSizes=train_data['size'],
                labels=train_data['labels'],
                points_surface=train_data['samples_normals'],
                labels_nml=train_data['normals'],
                labels_edges=train_data['edges'])

            # Backward on the combined loss with AMP scaling; grads are
            # cleared after the step so the next iteration starts clean.
            scaler.scale(error['Err(cmb)'].mean()).backward()
            scaler.step(optimizerG)
            scaler.update()

            optimizerG.zero_grad()
            netG.zero_grad()

            iter_net_time = time.time()
            # ETA: average iteration time so far extrapolated over the epoch.
            eta = ((iter_net_time - epoch_start_time) /
                   (train_idx + 1)) * len(train_data_loader) - (
                       iter_net_time - epoch_start_time)

            if train_idx % opt.freq_plot == 0:
                normal_loss = 0
                edge_loss = 0

                if opt.use_normal_loss:
                    normal_loss = error['Err(nml)'].mean().item()
                if opt.use_edge_loss:
                    edge_loss = error['Err(edges)'].mean().item()
                print(
                    'Name: {0} | Epoch: {1} | {2}/{3} | Err (Cmb): {4:.06f} | Err(Occ): {5:.06f} |  Err(Nml): {6:.06f} | Err(edges): {7:.06f} | LR: {8:.06f} | dataT: {9:.05f} | netT: {10:.05f} | ETA: {11:02d}:{12:02d}'
                    .format(opt.name, epoch, train_idx, len(train_data_loader),
                            error['Err(cmb)'].mean().item(),
                            error['Err(occ)'].mean().item(), normal_loss,
                            edge_loss, lr, iter_start_time - iter_data_time,
                            iter_net_time - iter_start_time, int(eta // 60),
                            int(eta - 60 * (eta // 60))))

            if not opt.debug and train_idx % opt.freq_save == 0 and train_idx != 0:
                torch.save(
                    netG.state_dict(),
                    '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(
                    netG.state_dict(), '%s/%s/netG_epoch_%d' %
                    (opt.checkpoints_path, opt.name, epoch))

            if train_idx % opt.freq_save_ply == 0:
                # Dump predicted occupancies and normals as point clouds for
                # visual debugging.
                save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
                r = pred[0].cpu()
                points = train_data['samples'][0].transpose(0, 1).cpu()
                save_samples_truncted_prob(save_path,
                                           points.detach().numpy(),
                                           r.detach().numpy())

                nml_source = pred[0] if opt.predict_normal else nmls[0]
                # Map normals from [-1, 1] to [0, 1] RGB.
                gt_labels_rgb = (
                    train_data['normals'][0].cpu().detach().numpy() + 1) * 0.5
                pred_labels_rgb = (nml_source.cpu().detach().numpy() + 1) * 0.5

                # FIX: the format strings were previously not interpolated,
                # writing to a literal path '%s/%s/normals_gt.ply'.
                save_samples_rgb(
                    '%s/%s/normals_gt.ply' % (opt.results_path, opt.name),
                    train_data['samples_normals'][0].cpu().detach().numpy().T,
                    gt_labels_rgb.T)
                save_samples_rgb(
                    '%s/%s/normals_pred.ply' % (opt.results_path, opt.name),
                    train_data['samples_normals'][0].cpu().detach().numpy().T,
                    pred_labels_rgb.T)

            iter_data_time = time.time()

        # update learning rate
        if isinstance(optimizerG, torch.optim.RMSprop):
            lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule,
                                      opt.gamma)

        # Run evaluation on a single GPU even in the multi-GPU case.
        if len(device_ids) > 1:
            netG.device_ids = [device_ids[0]]

        #### test
        with torch.no_grad():
            set_eval()

            if not opt.no_num_eval:
                test_losses = {}
                print('calc error (test) ...')
                test_errors = calc_error(coll, netG, cuda, test_dataset, 100)
                print(
                    'eval test OCC: {0:06f} NML: {1:06f} EDGES: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
                    .format(*test_errors))
                occ, nml, edges, IOU, prec, recall = test_errors
                test_losses['OCC(test)'] = occ
                test_losses['NML(test)'] = nml
                # FIX: key was 'EDGES(test' (missing closing paren).
                test_losses['EDGES(test)'] = edges
                test_losses['IOU(test)'] = IOU
                test_losses['prec(test)'] = prec
                test_losses['recall(test)'] = recall

                print('calc error (train) ...')
                train_dataset.is_train = False
                train_errors = calc_error(coll, netG, cuda, train_dataset, 100)
                train_dataset.is_train = True
                print(
                    'eval train OCC: {0:06f} NML: {1:06f} EDGES: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
                    .format(*train_errors))
                occ, nml, edges, IOU, prec, recall = train_errors
                test_losses['OCC(train)'] = occ
                test_losses['NML(train)'] = nml
                # FIX: key was 'EDGES(train' (missing closing paren).
                test_losses['EDGES(train)'] = edges
                test_losses['IOU(train)'] = IOU
                test_losses['prec(train)'] = prec
                test_losses['recall(train)'] = recall

            if not opt.no_gen_mesh:
                print('generate mesh (test) ...')
                test_data = None
                train_data = None
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    test_data = test_dataset[random.randint(
                        0,
                        len(test_dataset) - 1)]
                    test_data_batched = coll([test_data])
                    test_data_batched = move_to_gpu(test_data_batched, cuda)

                    save_path = '%s/%s/test/test_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, test_data['name'])
                    mesh_test = gen_mesh(opt, netG, cuda, test_data_batched,
                                         save_path)

                print('generate mesh (train) ...')
                train_dataset.is_train = False
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    # FIX: previously sampled with len(test_dataset), which
                    # bounds the index by the wrong dataset's size.
                    train_data = train_dataset[random.randint(
                        0,
                        len(train_dataset) - 1)]
                    train_data_batched = coll([train_data])
                    train_data_batched = move_to_gpu(train_data_batched, cuda)

                    save_path = '%s/%s/training/train_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, train_data['name'])
                    mesh_train = gen_mesh(opt, netG, cuda, train_data_batched,
                                          save_path)
                train_dataset.is_train = True

                tb.updateAfterEpoch(epoch, train_errors, test_errors,
                                    train_data['img'], test_data['img'])

        if len(device_ids) > 1:
            netG.device_ids = device_ids
# Example #5
def train(opt):
    """Fine-tune a pre-trained HGPIFuNet with an image-space VGG perceptual
    loss on predicted normal/depth images.

    Loads a fixed base checkpoint (hard-coded name/epoch below), then trains
    with a combined VGG perceptual loss over a rendered prediction image plus
    the network's own combined error, using AMP gradient scaling. Shows a
    live OpenCV preview each iteration and generates sample meshes per epoch.

    Args:
        opt: parsed ``BaseOptions`` namespace holding all training settings.
    """
    # set cuda: the first id in opt.gpu_ids is the primary device
    device_ids = [int(i) for i in opt.gpu_ids.split(",")]
    cuda = torch.device('cuda:%d' % device_ids[0])

    tb.initWriter(opt)

    train_dataset = TrainDataset(opt, phase='train')
    test_dataset = TrainDataset(opt, phase='test')

    print("Loading training data...")
    coll = MultiViewCollator(opt)
    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size, shuffle=not opt.serial_batches, collate_fn=coll,
                                   num_workers=opt.num_threads, pin_memory=opt.pin_memory)

    print('train data size: ', len(train_data_loader))
    # NOTE: batch size should be 1 and use all the points for evaluation;
    # the test set is indexed directly instead of via a DataLoader.
    print('test data size: ', len(test_dataset))

    # create net
    print("Num GPUs: " + str(torch.cuda.device_count()))

    # NOTE(review): base experiment name and resume epoch are hard-coded;
    # consider promoting these to options.
    name = "multiview_pifu_OCC_hg_bp__256_15000_nml_loss_edge_loss__mlp"
    model_opt = BaseOptions().loadOptFromFile(name=name)
    model_opt.use_edge_loss = False

    if len(device_ids) > 1:
        netG = MyDataParallel(HGPIFuNet(model_opt, train_dataset.projection_mode), device_ids=device_ids).to(device=cuda)
    else:
        netG = HGPIFuNet(model_opt, train_dataset.projection_mode).to(device=cuda)

    model_path = '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, name, 20)
    netG.load_state_dict(torch.load(model_path, map_location=cuda))

    optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.learning_rate, momentum=0, weight_decay=0)
    lr = opt.learning_rate
    print('Using Network: ', netG.name)

    def set_train():
        netG.train()

    def set_eval():
        netG.eval()

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s/training' % (opt.results_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s/test' % (opt.results_path, opt.name), exist_ok=True)

    BaseOptions().saveOptToFile(opt)
    vgg_loss = LossNetwork.VGGPerceptualLoss(False).to(device=cuda)
    scaler = torch.cuda.amp.GradScaler()
    # training
    start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch, 0)
    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()

        set_train()
        iter_data_time = time.time()

        for train_idx, train_data in enumerate(train_data_loader):
            iter_start_time = time.time()

            train_data = move_to_gpu(train_data, cuda)
            image_idx = 0

            features = netG.filter(train_data['images'])
            sampled_points = 0
            num_points = 50000
            points = train_data['samples_depth'][image_idx]
            total_points = points.shape[2]

            # Disabled experimental path ('while False'): chunked normal
            # estimation via autograd over the query points. Kept for
            # reference; renamed 'slice' (shadowed builtin) to 'pts_slice'
            # and fixed grad_outputs, which must match the shape of
            # 'outputs' (pred), not the inputs.
            current_normals_list = []
            while False and sampled_points < total_points:
                num_points = min(num_points, total_points - sampled_points)
                pts_slice = points[:, :, sampled_points:num_points + sampled_points]
                pred = netG.query(pts_slice, train_data['calib'], train_data['size'])
                new_nmls = torch.autograd.grad(outputs=pred, inputs=[pts_slice[0:1, :, :]],
                                               grad_outputs=torch.ones_like(pred),
                                               create_graph=True, retain_graph=True)
                points_query = reshape_sample_tensor(pts_slice, opt.num_views)

                current_normals_list.append(new_nmls[image_idx::opt.num_views, :, :])
                sampled_points += num_points

            points_query = reshape_sample_tensor(points, opt.num_views)

            gt_normals_img = train_data['normal_images'][image_idx].to(dtype=torch.float32)
            # Mask of pixels where the GT normal image has content.
            gt_normals_img_mask = (gt_normals_img > 0).to(dtype=torch.float32)

            sdf = netG.calc_pred(features, points_query, train_data['calib'], train_data['size'])

            _, normals, _, errorPred = netG.forward(train_data['images'], train_data['samples'], train_data['calib'], train_data['size'],
                                                    labels=train_data['labels'], points_surface=points_query, labels_nml=None)

            # Reassemble the per-point predictions into image space and
            # perturb the GT image by the (centered) prediction.
            img_orig_shape = train_data['normal_images'][image_idx].shape
            pred_normals_img = torch.reshape(sdf, (-1, 1, img_orig_shape[2], img_orig_shape[3]))
            pred_normals_img = gt_normals_img - (pred_normals_img - 0.5)
            pred_normals_img *= gt_normals_img_mask

            # Live visual sanity check: prediction next to ground truth.
            preview = np.concatenate([pred_normals_img[0].detach().cpu().numpy().transpose(1, 2, 0),
                                      gt_normals_img[0].detach().cpu().numpy().transpose(1, 2, 0)], axis=1)
            cv2.imshow("prediction", preview)
            cv2.waitKey(1)

            # Perceptual loss on the image plus the network's combined error.
            error = vgg_loss(pred_normals_img, gt_normals_img) + errorPred['Err(cmb)'].mean() * 10
            scaler.scale(error.mean()).backward()
            scaler.step(optimizerG)
            scaler.update()
            optimizerG.zero_grad()
            netG.zero_grad()

            iter_net_time = time.time()
            # ETA: average iteration time so far extrapolated over the epoch.
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                    iter_net_time - epoch_start_time)

            if train_idx % opt.freq_plot == 0:
                # FIX: format the loss as a Python float ('.item()') instead
                # of relying on Tensor.__format__ supporting ':.06f'.
                print(
                    'Name: {0} | Epoch: {1} | {2}/{3} | Err(VGG): {4:.06f} | LR: {5:.06f} | dataT: {6:.05f} | netT: {7:.05f} | ETA: {8:02d}:{9:02d}'.format(
                        opt.name, epoch, train_idx, len(train_data_loader), error.item(), lr, iter_start_time - iter_data_time, iter_net_time - iter_start_time, int(eta // 60), int(eta - 60 * (eta // 60))))

            if not opt.debug and train_idx % opt.freq_save == 0 and train_idx != 0:
                torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            if train_idx % opt.freq_save_ply == 0:
                # Dump predicted normals as a colored point cloud ([-1,1] -> [0,1]).
                save_path = '%s/%s/nml.ply' % (opt.results_path, opt.name)
                normals = normals.detach().cpu().numpy()
                points = points.detach().cpu().numpy()
                save_samples_rgb(save_path, points[0].T, (normals[0].T + 1) / 2)

            iter_data_time = time.time()

        # update learning rate
        if isinstance(optimizerG, torch.optim.RMSprop):
            lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)

        # Run evaluation on a single GPU even in the multi-GPU case.
        if len(device_ids) > 1:
            netG.device_ids = [device_ids[0]]

        #### test
        with torch.no_grad():
            set_eval()

            if not opt.no_gen_mesh:
                print('generate mesh (test) ...')
                test_data = None
                train_data = None
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    # NOTE(review): fixed index 6 — presumably a debugging
                    # choice; random sampling is commented out upstream.
                    test_data = test_dataset[6]
                    test_data_batched = coll([test_data])
                    test_data_batched = move_to_gpu(test_data_batched, cuda)

                    save_path = '%s/%s/test/test_eval_epoch%d_%s.obj' % (opt.results_path, opt.name, epoch, test_data['name'])
                    mesh_test = gen_mesh(opt, netG, cuda, test_data_batched, save_path)

                print('generate mesh (train) ...')
                train_dataset.is_train = False
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    # FIX: previously sampled with len(test_dataset), which
                    # bounds the index by the wrong dataset's size.
                    train_data = train_dataset[random.randint(0, len(train_dataset) - 1)]
                    train_data_batched = coll([train_data])
                    train_data_batched = move_to_gpu(train_data_batched, cuda)

                    save_path = '%s/%s/training/train_eval_epoch%d_%s.obj' % (opt.results_path, opt.name, epoch, train_data['name'])
                    mesh_train = gen_mesh(opt, netG, cuda, train_data_batched, save_path)
                train_dataset.is_train = True

        if len(device_ids) > 1:
            netG.device_ids = device_ids