Exemplo n.º 1
0
 def check(a):
     """Check ``jt.transpose`` against ``np.transpose`` for every axis order.

     Every permutation of ``a``'s axes is tried, followed by ``None`` (the
     no-argument form of both functions). For each case the shapes and the
     values must match NumPy. For explicit permutations that reorder at least
     two non-singleton axes, the captured op log must contain exactly one
     ``cutt_transpose`` JIT-key lookup line.

     ``self`` is taken from the enclosing test method (closure).
     """
     perms = list(permutations(range(a.ndim))) + [None]
     for perm in perms:
         with jt.log_capture_scope(log_silent=1,
                                   log_v=0,
                                   log_vprefix="op.cc=100") as raw_log:
             if perm:
                 x = np.transpose(a, perm)
                 y = jt.transpose(a, perm).data
             else:
                 x = np.transpose(a)
                 y = jt.transpose(a).data
             self.assertEqual(x.shape, y.shape)
         # Check values for every case, including perm=None. Previously the
         # ``continue`` below skipped this assertion for the default transpose.
         assert (x == y).all(), f"\n{x}\n{y}"
         if perm is None:
             continue
         logs = find_log_with_re(
             raw_log,
             "(Jit op key (not )?found: " + "cutt_transpose" + ".*)")
         # Size-1 axes are ignored here; only the relative order of the
         # remaining axes decides whether the permutation counts as
         # "out of order" for the log assertion below.
         moved = [axis for axis in perm if a.shape[axis] != 1]
         if moved != sorted(moved):
             assert len(logs) == 1
Exemplo n.º 2
0
 def execute(self, x):
     """Apply two BN -> ReLU -> conv stages and append ``x`` after the result.

     The concatenation is done along axis 1. Both tensors have axes 0 and 1
     swapped, are joined along axis 0 with ``jt.concat``, and the result is
     swapped back. The inputs are therefore assumed to be 4-D.
     """
     swap01 = (1, 0, 2, 3)
     features = self.conv1(nn.relu(self.bn1(x)))
     features = self.conv2(nn.relu(self.bn2(features)))
     joined = jt.concat([jt.transpose(features, swap01),
                         jt.transpose(x, swap01)], 0)
     return jt.transpose(joined, swap01)
Exemplo n.º 3
0
 def check(a):
     """Compare ``jt.transpose`` with ``np.transpose`` over all axis orders.

     Every permutation of ``a``'s axes is tried, followed by the no-argument
     form of both functions. Shapes must be equal and all elements must match.
     ``self`` comes from the enclosing test method (closure).
     """
     for perm in [*permutations(range(a.ndim)), None]:
         # An empty tuple of extra args selects the no-argument call form.
         extra = (perm,) if perm else ()
         expected = np.transpose(a, *extra)
         actual = jt.transpose(a, *extra).data
         self.assertEqual(expected.shape, actual.shape)
         assert (expected == actual).all(), f"\n{expected}\n{actual}"
Exemplo n.º 4
0
 def check(a):
     """Check the gradient through ``jt.transpose`` for every axis order.

     For each permutation of ``a``'s axes, and then the no-argument form, a
     float copy ``x`` of ``a`` is transposed into ``y``. The test then checks
     that ``jt.grad(y * y, x)`` has ``a``'s shape and equals ``2 * a``.
     ``self`` comes from the enclosing test method (closure).
     """
     for perm in [*permutations(range(a.ndim)), None]:
         src = jt.array(a).float()
         out = jt.transpose(src, perm) if perm else jt.transpose(src)
         grad = jt.grad(out * out, src).data
         self.assertEqual(grad.shape, a.shape)
         assert (grad == a * 2).all(), f"\n{grad}\n{a}\n{perm}"
Exemplo n.º 5
0
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` groups.

    The channel axis is viewed as ``(groups, channels_per_group)``. Those two
    axes are swapped and the result is flattened back into one channel axis.

    Args:
        x: 4-D Jittor Var unpacked as (batch, channels, height, width).
        groups: number of channel groups. It is expected to divide the channel
            count; otherwise the first reshape fails.

    Returns:
        A Var with the same shape as ``x`` whose channels are permuted.
    """
    # Use the Var's shape directly. ``x.data.shape`` forces the graph to run
    # and copies the whole tensor to a NumPy array just to read its shape.
    batchsize, num_channels, height, width = x.shape
    channels_per_group = num_channels // groups
    x = jt.reshape(x, [batchsize, groups, channels_per_group, height, width])
    x = jt.transpose(x, (0, 2, 1, 3, 4))
    x = jt.reshape(x, [batchsize, -1, height, width])
    return x
Exemplo n.º 6
0
 def __call__(self, plane=None, quat=None, weight=1):
     """Compute orthogonality penalties for plane normals and quaternion parts.

     For each non-empty input list, one column vector is taken from every
     element and the columns are stacked into ``X``. The penalty is
     ``(X @ X^T - self.eye)`` squared, summed over the last two axes,
     averaged, and multiplied by ``weight``.

     Args:
         plane: list of tensors. Columns 0:3 of each are normalized and used.
             Each element is presumably shaped (batch, >=3); confirm.
         quat: list of tensors. Columns 1:4 of each are used.
         weight: scalar multiplier applied to both penalties.

     Returns:
         ``(reg_plane, reg_rot)``. A term whose input is ``None`` or empty
         keeps its zero placeholder.
     """
     def orthogonality(columns):
         stacked = jt.contrib.concat(columns, dim=2)
         gram = jt.matmul(stacked, jt.transpose(stacked, [0, 2, 1]))
         return (gram - self.eye).pow(2).sum(2).sum(1).mean() * weight

     reg_rot = jt.transform.to_tensor(jt.array([0]))
     reg_plane = jt.transform.to_tensor(jt.array([0]))
     if plane:
         reg_plane = orthogonality(
             [normalize(p[:, 0:3]).unsqueeze(2) for p in plane])
     if quat:
         reg_rot = orthogonality([q[:, 1:4].unsqueeze(2) for q in quat])
     return (reg_plane, reg_rot)
Exemplo n.º 7
0
    def execute(self, conv4_3_feats, conv7_feats, conv8_2_feats, conv9_2_feats,
                conv10_2_feats, conv11_2_feats):
        """ Forward propagation.

        Args:
            conv4_3_feats: conv4_3 feature map, an array of dimensions (N, 512, 38, 38)
            conv7_feats: conv7 feature map, an array of dimensions (N, 1024, 19, 19)
            conv8_2_feats: conv8_2 feature map, an array of dimensions (N, 512, 10, 10)
            conv9_2_feats: conv9_2 feature map, an array of dimensions (N, 256, 5, 5)
            conv10_2_feats: conv10_2 feature map, an array of dimensions (N, 256, 3, 3)
            conv11_2_feats: conv11_2 feature map, an array of dimensions (N, 256, 1, 1)
        Return:
            8732 locations and class scores (i.e. w.r.t each prior box) for each image
        """
        batch_size = conv4_3_feats.shape[0]
        feature_maps = (conv4_3_feats, conv7_feats, conv8_2_feats,
                        conv9_2_feats, conv10_2_feats, conv11_2_feats)
        loc_heads = (self.loc_conv4_3, self.loc_conv7, self.loc_conv8_2,
                     self.loc_conv9_2, self.loc_conv10_2, self.loc_conv11_2)
        cls_heads = (self.cl_conv4_3, self.cl_conv7, self.cl_conv8_2,
                     self.cl_conv9_2, self.cl_conv10_2, self.cl_conv11_2)

        def to_rows(pred, width):
            # (N, C, H, W) -> (N, H, W, C), then group into rows of `width`.
            pred = jt.transpose(pred, [0, 2, 3, 1])
            return jt.reshape(pred, [batch_size, -1, width])

        # All localization heads run first, then all classification heads,
        # each in conv4_3 -> conv11_2 order.
        loc_rows = [to_rows(head(fmap), 4)
                    for head, fmap in zip(loc_heads, feature_maps)]
        cls_rows = [to_rows(head(fmap), self.n_classes)
                    for head, fmap in zip(cls_heads, feature_maps)]

        locs = jt.contrib.concat(loc_rows, dim=1)
        classes_scores = jt.contrib.concat(cls_rows, dim=1)
        return (locs, classes_scores)
Exemplo n.º 8
0
def train():
    """Train a NeRF model from the command-line configuration.

    Steps:
      1. Parse arguments with ``config_parser()``.
      2. Load an ``llff``, ``blender`` or ``deepvoxels`` dataset.
      3. Record the resolved arguments (and a copy of ``args.config``, if
         given) under ``<basedir>/<expname>/``.
      4. Build the model via ``create_nerf``.
      5. Optimize from iteration ``start + 1`` up to ``N_iters``, periodically
         saving checkpoints, rendering videos, logging to TensorBoard and
         rendering the test set.

    Returns ``None`` early for an unknown ``dataset_type``, and after
    rendering when ``args.render_only`` is set. Under MPI, checkpoints,
    videos and TensorBoard logs are written only by local rank 0.
    """
    parser = config_parser()
    args = parser.parse_args()

    # Load data
    intrinsic = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])
        # LLFF has no separate "full" test split. The periodic test-set render
        # below reads i_test_tot for every dataset type, so it must be defined
        # here (it previously raised NameError).
        i_test_tot = i_test

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        testskip = args.testskip
        faketestskip = args.faketestskip
        if jt.mpi and jt.mpi.local_rank() != 0:
            testskip = faketestskip
            faketestskip = 1
        # NOTE(review): the adjusted `testskip` / `faketestskip` locals above
        # are never read. The loader calls and the slice below use
        # `args.testskip` / `args.faketestskip` directly, so the per-rank
        # override has no effect. Confirm the intended behaviour before
        # wiring them in.
        if args.do_intrinsic:
            images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor, True)
        else:
            images, poses, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split
        # Keep the whole loaded test split for the every-i_tottest render.
        # i_test is thinned for the more frequent test renders.
        i_test_tot = i_test
        i_test = i_test[::args.faketestskip]

        near = args.near
        far = args.far
        print(args.do_intrinsic)
        print("hwf", hwf)
        print("near", near)
        print("far", far)

        if args.white_bkgd:
            # Composite RGB over a white background using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split
        # No separate "full" test split; see the note in the llff branch.
        i_test_tot = i_test

        # Bounds: mean camera distance from the origin, plus or minus 1.
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Render from the test-set camera poses instead of the loader's path.
    render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        # Open the source config with `with` as well so its handle is closed
        # (the bare open(...).read() leaked it).
        with open(f, 'w') as file, open(args.config, 'r') as config_file:
            file.write(config_file.read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Convert testing poses to a Jittor Var
    render_poses = jt.array(render_poses)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with jt.no_grad():
            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            # NOTE(review): unlike the other render_path calls in this
            # function, `intrinsic` is not passed here. Confirm whether
            # render-only should honour --do_intrinsic.
            rgbs, _ = render_path(render_poses,
                                  hwf,
                                  args.chunk,
                                  render_kwargs_test,
                                  savedir=testsavedir,
                                  render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    accumulation_steps = 1
    N_rand = args.N_rand // accumulation_steps
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Convert training data to Jittor Vars
    images = jt.array(images.astype(np.float32))
    poses = jt.array(poses)
    if use_batching:
        rays_rgb = jt.array(rays_rgb)

    N_iters = 51000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writer: created on local rank 0 only, so every later use of
    # `writer` must stay behind the same rank check.
    if not jt.mpi or jt.mpi.local_rank() == 0:
        date = str(datetime.datetime.now())
        date = (date[:date.rfind(":")].replace("-", "").replace(":", "")
                .replace(" ", "_"))
        gpu_idx = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
        log_dir = os.path.join("./logs", "summaries",
                               "log_" + date + "_gpu" + gpu_idx)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        writer = SummaryWriter(log_dir=log_dir)

    start = start + 1
    for i in trange(start, N_iters):
        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = jt.transpose(batch, (1, 0, 2))
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = jt.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image. Seeding with the iteration index makes
            # the image/pixel choice a deterministic function of `i`
            # (presumably so MPI ranks agree; confirm).
            np.random.seed(i)
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            if N_rand is not None:
                rays_o, rays_d = pinhole_get_rays(
                    H, W, focal, pose, intrinsic)  # (H, W, 3), (H, W, 3)
                if i < args.precrop_iters:
                    # Early iterations sample only from a centred crop.
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = jt.stack(
                        jt.meshgrid(
                            jt.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            jt.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)),
                        -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = jt.stack(
                        jt.meshgrid(jt.linspace(0, H - 1, H),
                                    jt.linspace(0, W - 1, W)), -1)  # (H, W, 2)

                coords = jt.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                select_coords = coords[select_inds].int()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = jt.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H,
                                        W,
                                        focal,
                                        chunk=args.chunk,
                                        rays=batch_rays,
                                        verbose=i < 10,
                                        retraw=True,
                                        **render_kwargs_train)
        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Add the loss on the 'rgb0' output when the renderer returns it.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0

        optimizer.backward(loss / accumulation_steps)
        if i % accumulation_steps == 0:
            optimizer.step()

        ###   update learning rate   ###
        # Exponential decay: the rate shrinks by `decay_rate` every
        # `decay_steps` global steps.
        decay_rate = 0.1
        decay_steps = args.lrate_decay * accumulation_steps * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        # Rest is logging
        if (i + 1) % args.i_weights == 0 and (not jt.mpi
                                              or jt.mpi.local_rank() == 0):
            print(i)
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            jt.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with jt.no_grad():
                rgbs, disps = render_path(render_poses,
                                          hwf,
                                          args.chunk,
                                          render_kwargs_test,
                                          intrinsic=intrinsic)
            if not jt.mpi or jt.mpi.local_rank() == 0:
                print('Done, saving', rgbs.shape, disps.shape)
                moviebase = os.path.join(
                    basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
                print('movie base ', moviebase)
                imageio.mimwrite(moviebase + 'rgb.mp4',
                                 to8b(rgbs),
                                 fps=30,
                                 quality=8)
                imageio.mimwrite(moviebase + 'disp.mp4',
                                 to8b(disps / np.max(disps)),
                                 fps=30,
                                 quality=8)

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            if i % args.i_img == 0:
                # Render one full validation view for TensorBoard.
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]
                with jt.no_grad():
                    rgb, disp, acc, extras = render(H,
                                                    W,
                                                    focal,
                                                    chunk=args.chunk,
                                                    c2w=pose,
                                                    intrinsic=intrinsic,
                                                    **render_kwargs_test)
                psnr = mse2psnr(img2mse(rgb, target))
                rgb = rgb.numpy()
                disp = disp.numpy()
                acc = acc.numpy()

                if not jt.mpi or jt.mpi.local_rank() == 0:
                    writer.add_image('test/rgb',
                                     to8b(rgb),
                                     global_step,
                                     dataformats="HWC")
                    writer.add_image('test/target',
                                     target.numpy(),
                                     global_step,
                                     dataformats="HWC")
                    writer.add_scalar('test/psnr', psnr.item(), global_step)

            jt.clean_graph()
            jt.sync_all()
            jt.gc()

            if i % args.i_testset == 0 and i > 0:
                # Every i_tottest iterations render the whole loaded test
                # split; otherwise render the thinned i_test subset.
                si_test = i_test_tot if i % args.i_tottest == 0 else i_test
                testsavedir = os.path.join(basedir, expname,
                                           'testset_{:06d}'.format(i))
                os.makedirs(testsavedir, exist_ok=True)
                print('test poses shape', poses[si_test].shape)
                with jt.no_grad():
                    rgbs, disps = render_path(jt.array(poses[si_test]),
                                              hwf,
                                              args.chunk,
                                              render_kwargs_test,
                                              savedir=testsavedir,
                                              intrinsic=intrinsic,
                                              expname=expname)
                jt.gc()
        global_step += 1