def check(a):
    """Compare ``jt.transpose`` with ``np.transpose`` for every axis order of *a*.

    Every permutation of ``range(a.ndim)`` is tried, followed by ``None`` (the
    argument-less transpose). For each case, Jittor's op log is captured while
    both results are computed, and the shapes and values must match.

    For explicit permutations, the non-singleton axes may not appear in
    ascending order. In that case exactly one captured log line must mention
    the ``cutt_transpose`` op key.

    Relies on ``self`` (the enclosing test case), ``permutations``,
    ``find_log_with_re``, ``np`` and ``jt`` from the surrounding scope.
    """
    perms = list(permutations(range(a.ndim))) + [None]
    for perm in perms:
        with jt.log_capture_scope(log_silent=1, log_v=0,
                                  log_vprefix="op.cc=100") as raw_log:
            if perm:
                x = np.transpose(a, perm)
                y = jt.transpose(a, perm).data
            else:
                x = np.transpose(a)
                y = jt.transpose(a).data
            self.assertEqual(x.shape, y.shape)
        # Read the captured log only after the scope has exited.
        logs = find_log_with_re(
            raw_log, "(Jit op key (not )?found: " + "cutt_transpose" + ".*)")
        if perm is not None:
            # Size-1 axes are ignored: moving them never reorders data.
            significant = [axis for axis in perm if a.shape[axis] != 1]
            in_order = all(
                prev < cur for prev, cur in zip(significant, significant[1:]))
            if not in_order:
                assert len(logs) == 1
        # Values are checked for every case, including the default transpose
        # (previously skipped by an early ``continue``).
        assert (x == y).all(), f"\n{x}\n{y}"
def execute(self, x):
    """Run BN-ReLU-conv1 then BN-ReLU-conv2 on *x* and concatenate *x* after it.

    The new features and the input are joined along axis 1. The join is done
    by swapping axes 0 and 1 on both tensors, concatenating along axis 0, and
    swapping back. Inputs are presumably NCHW, so axis 1 would be channels
    (TODO confirm).
    """
    hidden = nn.relu(self.bn1(x))
    hidden = self.conv1(hidden)
    hidden = nn.relu(self.bn2(hidden))
    hidden = self.conv2(hidden)

    swap_01 = (1, 0, 2, 3)
    joined = jt.concat(
        [jt.transpose(hidden, swap_01), jt.transpose(x, swap_01)], 0)
    return jt.transpose(joined, swap_01)
def check(a):
    """Assert that ``jt.transpose`` matches ``np.transpose`` for every axis order.

    All permutations of the axes of *a* are tried, then the argument-less
    transpose. Shapes and element values must agree. Uses ``self``,
    ``permutations``, ``np`` and ``jt`` from the enclosing scope.
    """
    candidates = [*permutations(range(a.ndim)), None]
    for order in candidates:
        if not order:
            expected, actual = np.transpose(a), jt.transpose(a).data
        else:
            expected, actual = np.transpose(a, order), jt.transpose(a, order).data
        self.assertEqual(expected.shape, actual.shape)
        assert (expected == actual).all(), f"\n{expected}\n{actual}"
def check(a):
    """Assert that the transpose gradient equals ``2 * a`` for every axis order.

    For each permutation (and the argument-less transpose), ``y`` is the
    transpose of ``x = jt.array(a).float()``. ``jt.grad(y * y, x)`` must have
    *a*'s shape and equal ``a * 2`` element-wise. Uses ``self``,
    ``permutations`` and ``jt`` from the enclosing scope.
    """
    for order in [*permutations(range(a.ndim)), None]:
        src = jt.array(a).float()
        moved = jt.transpose(src, order) if order else jt.transpose(src)
        grad = jt.grad(moved * moved, src).data
        self.assertEqual(grad.shape, a.shape)
        assert (grad == a * 2).all(), f"\n{grad}\n{a}\n{order}"
def channel_shuffle(x, groups):
    """Interleave the channels of *x* across *groups*.

    The channel axis is split into ``groups`` blocks of
    ``num_channels // groups`` channels. The group axis and the per-group
    axis are swapped, then flattened back into one channel axis.

    Args:
        x: 4-D array unpacked as (batchsize, num_channels, height, width).
        groups: number of channel groups.

    Returns:
        Array shaped ``[batchsize, -1, height, width]`` with shuffled channels.
    """
    # Read the dims from the tensor itself. ``x.data`` would first
    # materialise the whole tensor as an array just to get its shape.
    batchsize, num_channels, height, width = x.shape
    channels_per_group = num_channels // groups
    x = jt.reshape(x, [batchsize, groups, channels_per_group, height, width])
    x = jt.transpose(x, (0, 2, 1, 3, 4))
    return jt.reshape(x, [batchsize, -1, height, width])
def __call__(self, plane=None, quat=None, weight=1):
    """Compute orthogonality penalties for the ``plane`` and ``quat`` inputs.

    For each non-empty input list, column slices of every element are stacked
    along axis 2 into ``x``. The penalty is
    ``((x @ swapaxes(x, 1, 2) - self.eye) ** 2).sum(2).sum(1).mean() * weight``.
    ``plane`` uses normalised columns ``0:3``; ``quat`` uses columns ``1:4``.

    Returns:
        ``(reg_plane, reg_rot)``. A term stays at its zero placeholder when
        the corresponding input is ``None`` or empty.
    """
    def _ortho_penalty(columns):
        # Each column is presumably (B, 3, 1), giving x of shape (B, 3, k)
        # -- TODO confirm.
        stacked = jt.contrib.concat(columns, dim=2)
        swapped = jt.transpose(stacked, [0, 2, 1])
        gram = jt.matmul(stacked, swapped)
        return (gram - self.eye).pow(2).sum(2).sum(1).mean() * weight

    reg_rot = jt.transform.to_tensor(jt.array([0]))
    reg_plane = jt.transform.to_tensor(jt.array([0]))
    if plane:
        reg_plane = _ortho_penalty(
            [normalize(item[:, 0:3]).unsqueeze(2) for item in plane])
    if quat:
        reg_rot = _ortho_penalty([item[:, 1:4].unsqueeze(2) for item in quat])
    return (reg_plane, reg_rot)
def execute(self, conv4_3_feats, conv7_feats, conv8_2_feats, conv9_2_feats, conv10_2_feats, conv11_2_feats):
    """Forward propagation.

    Each feature map is passed through its localisation head and its
    classification head. Each head output is permuted with
    ``[0, 2, 3, 1]`` and reshaped to ``[batch_size, -1, width]``:
    width 4 for locations, ``self.n_classes`` for scores. The six results
    of each kind are concatenated along dim 1.

    Args:
        conv4_3_feats: conv4_3 feature map, a array of dimensions (N, 512, 38, 38)
        conv7_feats: conv7 feature map, a array of dimensions (N, 1024, 19, 19)
        conv8_2_feats: conv8_2 feature map, a array of dimensions (N, 512, 10, 10)
        conv9_2_feats: conv9_2 feature map, a array of dimensions (N, 256, 5, 5)
        conv10_2_feats: conv10_2 feature map, a array of dimensions (N, 256, 3, 3)
        conv11_2_feats: conv11_2 feature map, a array of dimensions (N, 256, 1, 1)

    Return:
        ``(locs, classes_scores)``: 8732 locations and class scores
        (i.e. w.r.t each prior box) for each image.
    """
    batch_size = conv4_3_feats.shape[0]
    feature_maps = [conv4_3_feats, conv7_feats, conv8_2_feats,
                    conv9_2_feats, conv10_2_feats, conv11_2_feats]
    loc_heads = [self.loc_conv4_3, self.loc_conv7, self.loc_conv8_2,
                 self.loc_conv9_2, self.loc_conv10_2, self.loc_conv11_2]
    cls_heads = [self.cl_conv4_3, self.cl_conv7, self.cl_conv8_2,
                 self.cl_conv9_2, self.cl_conv10_2, self.cl_conv11_2]

    def _to_rows(pred, width):
        # Move axis 1 last, then flatten everything but the batch into
        # rows of ``width`` values.
        pred = jt.transpose(pred, [0, 2, 3, 1])
        return jt.reshape(pred, [batch_size, -1, width])

    loc_parts = [_to_rows(head(fmap), 4)
                 for head, fmap in zip(loc_heads, feature_maps)]
    cls_parts = [_to_rows(head(fmap), self.n_classes)
                 for head, fmap in zip(cls_heads, feature_maps)]

    locs = jt.contrib.concat(loc_parts, dim=1)
    classes_scores = jt.contrib.concat(cls_parts, dim=1)
    return (locs, classes_scores)
def train():
    """Train a NeRF model configured by the command-line arguments.

    Steps, in order:
    - Parse arguments with ``config_parser()``.
    - Load the dataset selected by ``args.dataset_type`` (``llff``,
      ``blender`` or ``deepvoxels``).
    - Write ``args.txt`` (and ``config.txt`` if given) under
      ``basedir/expname``.
    - Build the model with ``create_nerf``.
    - If ``args.render_only`` is set, render the test poses to a video and
      return.
    - Otherwise run the optimisation loop from ``start + 1`` up to
      ``N_iters``.

    The loop periodically saves checkpoints, renders videos, logs TensorBoard
    summaries and renders the test set.
    """
    parser = config_parser()
    args = parser.parse_args()

    # Load data
    intrinsic = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
        # No separate "full" test split for this loader. Defined here because
        # the periodic test-set render reads it (previously a NameError).
        i_test_tot = i_test

    elif args.dataset_type == 'blender':
        testskip = args.testskip
        faketestskip = args.faketestskip
        if jt.mpi and jt.mpi.local_rank() != 0:
            testskip = faketestskip
            faketestskip = 1
        # NOTE(review): the per-rank ``testskip``/``faketestskip`` values
        # above are never used; the code below reads ``args.testskip`` and
        # ``args.faketestskip``. Confirm the intent before switching, since
        # all ranks must agree on the test poses they render.
        if args.do_intrinsic:
            images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor, True)
        else:
            images, poses, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        i_test_tot = i_test
        i_test = i_test[::args.faketestskip]

        near = args.near
        far = args.far
        print(args.do_intrinsic)
        print("hwf", hwf)
        print("near", near)
        print("far", far)

        if args.white_bkgd:
            # Composite RGB over a white background using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.
        # No separate "full" test split for this loader (see llff branch).
        i_test_tot = i_test

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Videos and render-only mode use the test poses, not the loader's path.
    render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        # Close the source config too (it was previously left open).
        with open(f, 'w') as file, open(args.config, 'r') as src:
            file.write(src.read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = jt.array(render_poses)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with jt.no_grad():
            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses, hwf, args.chunk,
                                  render_kwargs_test, savedir=testsavedir,
                                  render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs), fps=30, quality=8)
            return

    # Prepare raybatch tensor if batching random rays
    accumulation_steps = 1
    N_rand = args.N_rand // accumulation_steps
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # train images only
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move training data to GPU
    images = jt.array(images.astype(np.float32))
    poses = jt.array(poses)
    if use_batching:
        rays_rgb = jt.array(rays_rgb)

    N_iters = 51000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writer (rank 0 only; ``writer`` is only used under the same guard)
    if not jt.mpi or jt.mpi.local_rank() == 0:
        date = str(datetime.datetime.now())
        date = (date[:date.rfind(":")]
                .replace("-", "")
                .replace(":", "")
                .replace(" ", "_"))
        gpu_idx = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
        log_dir = os.path.join("./logs", "summaries",
                               "log_" + date + "_gpu" + gpu_idx)
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    start = start + 1
    for i in trange(start, N_iters):
        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = jt.transpose(batch, (1, 0, 2))
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = jt.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random from one image; seeding with ``i`` makes each step's
            # image/ray choice reproducible.
            np.random.seed(i)
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = pinhole_get_rays(
                    H, W, focal, pose, intrinsic)  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = jt.stack(
                        jt.meshgrid(
                            jt.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            jt.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)),
                        -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = jt.stack(
                        jt.meshgrid(jt.linspace(0, H - 1, H),
                                    jt.linspace(0, W - 1, W)),
                        -1)  # (H, W, 2)

                coords = jt.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].int()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = jt.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk,
                                        rays=batch_rays, verbose=i < 10,
                                        retraw=True, **render_kwargs_train)

        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Coarse-network output also contributes to the loss.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0

        optimizer.backward(loss / accumulation_steps)
        if i % accumulation_steps == 0:
            optimizer.step()

        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * accumulation_steps * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        # Rest is logging
        if (i + 1) % args.i_weights == 0 and (not jt.mpi or jt.mpi.local_rank() == 0):
            print(i)
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            jt.save(
                {
                    'global_step': global_step,
                    'network_fn_state_dict':
                        render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                        render_kwargs_train['network_fine'].state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with jt.no_grad():
                rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                          render_kwargs_test,
                                          intrinsic=intrinsic)
            if not jt.mpi or jt.mpi.local_rank() == 0:
                print('Done, saving', rgbs.shape, disps.shape)
                moviebase = os.path.join(
                    basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
                print('movie base ', moviebase)
                imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs),
                                 fps=30, quality=8)
                imageio.mimwrite(moviebase + 'disp.mp4',
                                 to8b(disps / np.max(disps)),
                                 fps=30, quality=8)

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")

        if i % args.i_img == 0:
            img_i = np.random.choice(i_val)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            with jt.no_grad():
                rgb, disp, acc, extras = render(H, W, focal,
                                                chunk=args.chunk, c2w=pose,
                                                intrinsic=intrinsic,
                                                **render_kwargs_test)
            psnr = mse2psnr(img2mse(rgb, target))
            rgb = rgb.numpy()
            disp = disp.numpy()
            acc = acc.numpy()
            if not jt.mpi or jt.mpi.local_rank() == 0:
                writer.add_image('test/rgb', to8b(rgb), global_step,
                                 dataformats="HWC")
                writer.add_image('test/target', target.numpy(), global_step,
                                 dataformats="HWC")
                writer.add_scalar('test/psnr', psnr.item(), global_step)
            jt.clean_graph()
            jt.sync_all()
            jt.gc()

        if i % args.i_testset == 0 and i > 0:
            si_test = i_test_tot if i % args.i_tottest == 0 else i_test
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[si_test].shape)
            with jt.no_grad():
                rgbs, disps = render_path(jt.array(poses[si_test]), hwf,
                                          args.chunk, render_kwargs_test,
                                          savedir=testsavedir,
                                          intrinsic=intrinsic,
                                          expname=expname)
            jt.gc()

        global_step += 1