Example #1
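    # Evaluator setup: build the pixel-NeRF model from the config, load its
    # pretrained weights, and construct a NeRFRenderer from the renderer config.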
    def __init__(self):
        super().__init__()
        self.net = make_model(conf["model"]).to(device=device)
        self.net.load_weights(args)
        self.renderer = NeRFRenderer.from_conf(
            conf["renderer"],
            white_bkgd=not args.black,
            eval_batch_size=args.ray_batch_size).to(device=device)
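The snippet ends right after constructing the renderer. As a minimal sketch of how the later examples drive it, a method like the following (the name render, the render_rays argument, and the 80,000-ray chunk size are illustrative, not part of the original class) binds the model to the renderer and renders rays in memory-bounded chunks, following the pattern of Examples #4 and #5:

    def render(self, render_rays):
        # Bind the network and renderer into a single evaluation module,
        # optionally split across the GPUs in args.gpu_id.
        render_par = self.renderer.bind_parallel(self.net, args.gpu_id,
                                                 simple_output=True).eval()
        all_rgb = []
        # Render in chunks of 80,000 rays to keep peak GPU memory bounded.
        for rays in torch.split(render_rays.view(-1, 8), 80000, dim=0):
            rgb, _depth = render_par(rays[None])  # prepend a batch dimension
            all_rgb.append(rgb[0])
        return torch.cat(all_rgb)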
Example #2
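# Evaluation setup: optionally rescale the render resolution, then build the model
# and renderer and wrap them for (multi-GPU) parallel rendering.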
if args.scale != 1.0:
    Ht = int(H * args.scale)
    Wt = int(W * args.scale)
    if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10:
        warnings.warn(
            "Inexact scaling, please check {} times ({}, {}) is integral".
            format(args.scale, H, W))
    H, W = Ht, Wt

net = make_model(conf["model"]).to(device=device)
net.load_weights(args)

renderer = NeRFRenderer.from_conf(
    conf["renderer"],
    lindisp=dset.lindisp,
    eval_batch_size=args.ray_batch_size,
).to(device=device)

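# Bind the renderer and the network into a single evaluation module, parallelized
# across the GPUs listed in args.gpu_id.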
render_par = renderer.bind_parallel(net, args.gpu_id,
                                    simple_output=True).eval()

# Get the distance from camera to origin
z_near = dset.z_near
z_far = dset.z_far

print("Generating rays")

dtu_format = hasattr(dset, "sub_format") and dset.sub_format == "dtu"

if dtu_format:
Example #3
                                  training=True,
                                  default_ray_batch_size=128)
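# Training setup: pick the CUDA device, load the train/val splits, build the model
# (optionally freezing the image encoder), and construct the renderer.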
device = util.get_cuda(args.gpu_id[0])

dset, val_dset, _ = get_split_dataset(args.dataset_format, args.datadir)
print("dset z_near {}, z_far {}, lindisp {}".format(dset.z_near, dset.z_far,
                                                    dset.lindisp))

net = make_model(conf["model"]).to(device=device)
net.stop_encoder_grad = args.freeze_enc
if args.freeze_enc:
    print("Encoder frozen")
    net.encoder.eval()

renderer = NeRFRenderer.from_conf(
    conf["renderer"],
    lindisp=dset.lindisp,
).to(device=device)

# Parallelize
render_par = renderer.bind_parallel(net, args.gpu_id).eval()

nviews = list(map(int, args.nviews.split()))


class PixelNeRFTrainer(trainlib.Trainer):
    def __init__(self):
        super().__init__(net,
                         dset,
                         val_dset,
                         args,
                         conf["train"],
Example #4
        # Find points whose coordinates fall outside the volume and temporarily
        # move them to index 0 so the lookup below stays valid.
        mask = ((xyz[0, :] < 0) | (xyz[1, :] < 0) | (xyz[2, :] < 0)
                | (xyz[0, :] > x_lim) | (xyz[1, :] > y_lim) | (xyz[2, :] > z_lim))
        xyz[:, mask] = 0

        sigma = self.img[tuple(xyz)]
        # Anything out of bounds set back to 0
        sigma[mask] = 0.0
        sigma = sigma.reshape(1, -1, 1)
        rgb = torch.ones(sigma.size(0), sigma.size(1), 3).to(device)
        return torch.cat((rgb, sigma), dim=-1).to(device)


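# Use the CT volume itself in place of a learned network: the renderer queries it
# for per-point density (sigma) and a constant white color.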
image = CTImage(torch.tensor(arr).to(device))
renderer = NeRFRenderer(
    n_coarse=64, n_fine=32, n_fine_depth=16, depth_std=0.01, sched=[], white_bkgd=False, eval_batch_size=50000
).to(device=device)
render_par = renderer.bind_parallel(image, [0], simple_output=True).eval()

render_rays = util.gen_rays(render_poses, W, H, focal, z_near, z_far).to(device=device)

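# Render all rays in chunks of 80,000 to keep peak GPU memory bounded.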
all_rgb_fine = []
for rays in tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
    rgb, _depth = render_par(rays[None])
    all_rgb_fine.append(rgb[0])
_depth = None
rgb_fine = torch.cat(all_rgb_fine)
frames = (rgb_fine.view(num_views, H, W, 3).cpu().numpy() * 255).astype(
    np.uint8
)
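A typical follow-up, not shown in the snippet, is to write the rendered frames to disk; a minimal sketch assuming imageio is installed (output path and frame rate are arbitrary):

import imageio

imageio.mimwrite("render.mp4", frames, fps=30, quality=8)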
Example #5
        sigma = self.img[tuple(xyz)]
        # Anything out of bounds set as air
        sigma[mask] = 0
        sigma = sigma.reshape(1, -1, 1)
        rgb = torch.ones(1, sigma.size(1), 3).to(device)
        return torch.cat((rgb, sigma), dim=-1).to(device)


focal = torch.tensor(focal, dtype=torch.float32, device=device)

# TODO: Change num coarse and fine to take into account each voxel exactly once
image = CTImage(torch.tensor(arr).to(device))
renderer = NeRFRenderer(n_coarse=512,
                        depth_std=0.01,
                        sched=[],
                        white_bkgd=False,
                        composite_x_ray=True,
                        eval_batch_size=50000,
                        lindisp=True).to(device=device)
render_par = renderer.bind_parallel(image, [0], simple_output=True).eval()

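# Generate one ray per detector pixel for a sensor with the given pixel resolution
# and physical dimensions.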
render_rays = util.gen_rays_variable_sensor(render_poses, width_pixels,
                                            height_pixels, width, height,
                                            focal, z_near, z_far).to(device)

all_rgb_fine = []
for rays in tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
    rgb, _depth = render_par(rays[None])
    all_rgb_fine.append(rgb[0])
_depth = None
rgb_fine = torch.cat(all_rgb_fine)
Example #6
        action="store_true",
        help="Do not store video (only image frames will be written)",
    )
    return parser


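# Parse command-line arguments together with the experiment config, and always
# resume from the experiment's saved checkpoint.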
args, conf = util.args.parse_args(
    extra_args,
    default_expname="srn_car",
    default_data_format="srn",
)
args.resume = True

device = util.get_cuda(args.gpu_id[0])
net = make_model(conf["model"]).to(device=device).load_weights(args)
renderer = NeRFRenderer.from_conf(
    conf["renderer"], eval_batch_size=args.ray_batch_size).to(device=device)
render_par = renderer.bind_parallel(net, args.gpu_id,
                                    simple_output=True).eval()

z_near, z_far = args.z_near, args.z_far
focal = torch.tensor(args.focal, dtype=torch.float32, device=device)

in_sz = args.size
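# Output resolution: a single value gives a square image, two values are read as W H.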
sz = list(map(int, args.out_size.split()))
if len(sz) == 1:
    H = W = sz[0]
else:
    assert len(sz) == 2
    W, H = sz

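# Coordinate-convention transform used to move camera poses to/from Blender's axes.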
_coord_to_blender = util.coord_to_blender()
Example #7
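# Variant of the setup in Example #2: the same resolution scaling and renderer
# construction, plus an explicit warning that extra GPUs are ignored.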
    Ht = int(H * args.scale)
    Wt = int(W * args.scale)
    if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10:
        warnings.warn(
            "Inexact scaling, please check {} times ({}, {}) is integral".
            format(args.scale, H, W))
    H, W = Ht, Wt

net = make_model(conf["model"]).to(device=device)
net.load_weights(args)
if len(extra_gpus):
    warnings.warn("Multi GPU not implemented")

renderer = NeRFRenderer.from_conf(
    conf["renderer"],
    white_bkgd=not args.black,
    lindisp=dset.lindisp,
    eval_batch_size=args.ray_batch_size,
).to(device=device)

# Get the distance from camera to origin, for normalization of z when training.
# NOTE: we DO NOT actually need the camera location at test time.
# I am using canonical coordinates only for convenience in the current implementation.
z_near = dset.z_near
z_far = dset.z_far

print("Generating rays")

dtu_format = hasattr(dset, "sub_format") and dset.sub_format == "dtu"

if dtu_format:
    print("Using DTU format")