def add_self_loops(edge_index, edge_weight: Optional[Var] = None,
                   fill_value: float = 1., num_nodes: Optional[int] = None):
    r"""Append a self-loop :math:`(i,i) \in \mathcal{E}` for every node
    :math:`i \in \mathcal{V}` of the graph given by :attr:`edge_index`.

    All existing edges (including any self-loops already present) are kept,
    and exactly one new loop per node is appended after them. For a weighted
    graph, every appended loop receives the weight :obj:`fill_value`.

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights; must have
            one entry per column of :attr:`edge_index`.
            (default: :obj:`None`)
        fill_value (float, optional): Weight of the appended self-loops when
            :obj:`edge_weight` is not :obj:`None`. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`Var int32`, :class:`Var`)
    """
    node_count = maybe_num_nodes(edge_index, num_nodes)

    # (2, N) block whose two rows are both 0..N-1.
    node_ids = jt.arange(0, node_count, dtype=Var.int32)
    new_loops = node_ids.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        loop_weights = init.constant((node_count, ), edge_weight.dtype,
                                     fill_value)
        edge_weight = jt.concat([edge_weight, loop_weights], dim=0)

    return jt.concat([edge_index, new_loops], dim=1), edge_weight
def execute(self, x):
    """Res2Net-style bottleneck forward pass.

    ``conv1 -> bn1 -> relu`` is applied, and the result is split along dim 1
    into ``self.width``-sized chunks. Each of the first ``self.nums`` chunks
    goes through ``self.convs[k] -> self.bns[k] -> relu``. When
    ``self.stype != 'stage'``, every chunk after the first is first added to
    the previous branch output. The branch outputs are concatenated on
    dim 1. When ``self.scale != 1``, the leftover chunk is appended as-is
    ('normal') or after ``self.pool`` ('stage'). Finally
    ``conv3 -> bn3`` runs, the (optionally downsampled) input is added, and
    relu is applied.
    """
    stem = self.relu(self.bn1(self.conv1(x)))
    chunks = jt.split(stem, self.width, 1)

    out = stem
    branch = None
    for k in range(self.nums):
        if k == 0 or self.stype == 'stage':
            branch = chunks[k]
        else:
            # Hierarchical residual: feed the previous branch output forward.
            branch = branch + chunks[k]
        branch = self.relu(self.bns[k](self.convs[k](branch)))
        out = branch if k == 0 else jt.concat((out, branch), 1)

    if self.scale != 1:
        leftover = chunks[self.nums]
        if self.stype == 'normal':
            out = jt.concat((out, leftover), 1)
        elif self.stype == 'stage':
            out = jt.concat((out, self.pool(leftover)), 1)

    out = self.bn3(self.conv3(out))

    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return self.relu(out)
def execute(self, x):
    """NeRF MLP forward pass.

    ``x`` is split on its last axis into ``self.input_ch`` position features
    and ``self.input_ch_views`` view-direction features. The positions pass
    through ``self.pts_linears`` with relu. After every layer whose index is
    listed in ``self.skips``, the raw position input is concatenated back in
    front of the hidden state.

    With ``self.use_viewdirs``, the output is ``concat([rgb, alpha], -1)``:
    alpha comes straight from the position features, and rgb is computed
    after concatenating the view directions. Otherwise the output is
    ``self.output_linear`` applied to the position features.
    """
    pts, views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1)

    hidden = pts
    for layer_idx, layer in enumerate(self.pts_linears):
        hidden = jt.nn.relu(layer(hidden))
        if layer_idx in self.skips:
            hidden = jt.concat([pts, hidden], -1)

    if not self.use_viewdirs:
        return self.output_linear(hidden)

    alpha = self.alpha_linear(hidden)
    hidden = jt.concat([self.feature_linear(hidden), views], -1)
    for layer in self.views_linears:
        hidden = jt.nn.relu(layer(hidden))
    rgb = self.rgb_linear(hidden)
    return jt.concat([rgb, alpha], -1)
def execute(self, pcd, prev_s):
    """
    One point-moving step: predict a per-point displacement and apply it.

    Args:
        pcd: b, n, 3
        prev_s: dict of previous recurrent states under the keys 'l0', 'l1'
            and 'l2'; each entry is overwritten in place with the new state.
    Returns:
        point_cloud: pcd moved by delta_xyz, transposed back to (b, n, 3)
        delta_xyz: displacement added to the (b, 3, n) transposed input,
            scaled by 1 / 10 ** (self.step - 1)
    """
    batch, num_pts, _ = pcd.shape
    pcd_bcn = pcd.transpose(0, 2, 1)

    l0_xyz, l0_points = pcd, pcd_bcn
    if self.if_noise:
        # Extra Gaussian channels appended to the input point features.
        jitter = init.gauss([batch, 3, num_pts], 'float', mean=0.0,
                            std=self.noise_stdv)
        l0_points = jittor.concat([l0_points, jitter], 1)

    # Encoder: three set-abstraction levels.
    l1_xyz, l1_points = self.sa_module_1(l0_xyz, l0_points)  # b, 512, 128 (bnc)
    l2_xyz, l2_points = self.sa_module_2(l1_xyz, l1_points)
    l3_xyz, l3_points = self.sa_module_3(l2_xyz, l2_points)

    # Decoder: feature propagation, each level refined by its recurrent unit.
    l2_points = self.fp_module_3(l2_xyz, l3_xyz, l2_points, l3_points)
    l2_points, prev_s['l2'] = self.unit_3(l2_points, prev_s['l2'])

    l1_points = self.fp_module_2(l1_xyz, l2_xyz, l1_points, l2_points)
    l1_points, prev_s['l1'] = self.unit_2(l1_points, prev_s['l1'])

    l0_points = self.fp_module_1(l0_xyz, l1_xyz,
                                 concat([pcd_bcn, pcd_bcn], dim=1), l1_points)
    l0_points, prev_s['l0'] = self.unit_1(l0_points, prev_s['l0'])  # (B, 128, 2048)

    latent_noise = init.gauss([batch, 32, num_pts], 'float', mean=0.0, std=1.0)
    features = concat([l0_points, latent_noise], dim=1)
    # Step-dependent magnitude: later steps make smaller moves.
    delta_xyz = self.tanh(self.mlp_conv(features)) * 1.0 / 10 ** (self.step - 1)

    return (pcd_bcn + delta_xyz).transpose(0, 2, 1), delta_xyz
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw ``N_samples`` values per row by inverse-transform sampling.

    ``weights`` defines a piecewise-constant distribution over the intervals
    delimited by ``bins``. ``bins`` needs one more entry on its last axis
    than ``weights``, and both are indexed as 2-D (batch, ...).

    Args:
        bins: (batch, W + 1) interval edges.
        weights: (batch, W) non-negative interval weights.
        N_samples: number of samples per row.
        det: if True, use evenly spaced quantiles in [0, 1] instead of
            uniform random ones.

    Returns:
        (batch, N_samples) sampled values.
    """
    # Normalise to a pdf; the small offset keeps all-zero rows from yielding NaNs.
    padded = weights + 1e-5
    pdf = padded / jt.sum(padded, -1, keepdims=True)
    cdf = jt.cumsum(pdf, -1)
    cdf = jt.concat([jt.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if det:
        u = jt.linspace(0., 1., steps=N_samples).expand(sample_shape)
    else:
        u = jt.random(sample_shape)

    # Find the cdf interval each quantile falls in, clamped to valid indices.
    idx = jt.searchsorted(cdf, u, right=True)
    lo = jt.maximum(jt.zeros_like(idx - 1), idx - 1)
    hi = jt.minimum((cdf.shape[-1] - 1) * jt.ones_like(idx), idx)
    lo_hi = jt.stack([lo, hi], -1)  # (batch, N_samples, 2)

    gather_shape = [lo_hi.shape[0], lo_hi.shape[1], cdf.shape[-1]]
    cdf_g = jt.gather(cdf.unsqueeze(1).expand(gather_shape), 2, lo_hi)
    bins_g = jt.gather(bins.unsqueeze(1).expand(gather_shape), 2, lo_hi)

    # Linear interpolation inside the interval; near-empty intervals use a
    # denominator of 1 so the division stays finite.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom[denom < 1e-5] = 1.0
    frac = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + frac * (bins_g[..., 1] - bins_g[..., 0])
def execute(self, x):
    """Pre-activation bottleneck with dense (concatenating) shortcut.

    Computes ``bn1 -> relu -> conv1 -> bn2 -> relu -> conv2`` and returns the
    new feature maps followed by the input, concatenated along dim 1
    (channels for NCHW input).
    """
    out = self.conv1(nn.relu(self.bn1(x)))
    out = self.conv2(nn.relu(self.bn2(out)))
    # Concatenating directly on dim 1 gives the same result as the former
    # transpose(1,0,2,3) -> concat(dim=0) -> transpose(1,0,2,3) sequence,
    # without three extra transposes.
    return jt.concat([out, x], 1)
def sample(N_rays, N_samples, lindisp, perturb, near, far):
    """Return ``[N_rays, N_samples]`` sample depths between ``near`` and ``far``.

    Depths are evenly spaced in depth, or in inverse depth (disparity) when
    ``lindisp`` is set. When ``perturb > 0``, each depth is replaced by a
    uniform draw within its stratum. The stratum is bounded by the midpoints
    to its neighbours, and by the first/last depth at the ends.
    """
    t = jt.linspace(0., 1., steps=N_samples)
    if lindisp:
        z_vals = 1. / (1. / near * (1. - t) + 1. / far * (t))
    else:
        z_vals = near * (1. - t) + far * (t)
    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = jt.concat([mids, z_vals[..., -1:]], -1)
        lower = jt.concat([z_vals[..., :1], mids], -1)
        z_vals = lower + (upper - lower) * jt.random(z_vals.shape)
    return z_vals
def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
    r"""Removes the isolated nodes from the graph given by :attr:`edge_index`
    with optional edge attributes :attr:`edge_attr`.
    In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter
    out isolated node features later on.
    Self-loops are preserved for non-isolated nodes.

    Node ids in the returned :attr:`edge_index` are renumbered to
    :obj:`0 .. mask.sum() - 1`, in the original node order.

    Args:
        edge_index (Var int32): The edge indices.
        edge_attr (Var, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (Var int32, Var, Var bool)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # NOTE(review): presumably segregate_self_loops returns the non-loop edges
    # (and attrs) first and the self-loops second; the mask below relies on
    # self-loops alone not marking a node as connected.
    out = segregate_self_loops(edge_index, edge_attr)
    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out

    # mask[i] is True iff node i is an endpoint of at least one kept edge.
    mask = jt.zeros((num_nodes), dtype=Var.bool)
    mask[edge_index.view(-1)] = 1

    # assoc: old node id -> new compacted id (-1 for removed nodes).
    assoc = jt.full((num_nodes, ), -1, dtype=Var.int32)
    assoc[mask] = jt.arange(mask.sum())
    edge_index = assoc[edge_index]

    # Keep only the self-loops whose node survives.
    loop_mask = jt.zeros_like(mask)
    loop_mask[loop_edge_index[0]] = 1
    loop_mask = loop_mask & mask
    # loop_assoc: node id -> column of its self-loop in loop_edge_index.
    # NOTE(review): with duplicate self-loops on a node only one column is
    # retained; which one depends on scatter write order.
    loop_assoc = jt.full_like(assoc, -1)
    loop_assoc[loop_edge_index[0]] = jt.arange(loop_edge_index.size(1))
    loop_idx = loop_assoc[loop_mask]
    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]

    edge_index = jt.concat([edge_index, loop_edge_index], dim=1)

    if edge_attr is not None:
        loop_edge_attr = loop_edge_attr[loop_idx]
        edge_attr = jt.concat([edge_attr, loop_edge_attr], dim=0)

    return edge_index, edge_attr, mask
def integrator(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False):
    """Alpha-composite per-sample model predictions into per-ray quantities.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model;
            channels 0-2 are passed through a sigmoid as colour, channel 3 is
            passed through relu as density.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        raw_noise_std: standard deviation of Gaussian noise added to the
            density channel before the relu; no noise unless > 0.
        white_bkgd: if True, composite the colour over a white background.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    # Spacing between consecutive samples; the final interval is 1e10.
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    far_pad = jt.array(np.array([1e10]).astype(np.float32)).expand(
        deltas[..., :1].shape)
    deltas = jt.concat([deltas, far_pad], -1)  # [N_rays, N_samples]
    # Convert to world-space lengths using the norm of each ray direction.
    deltas = deltas * jt.norm(rays_d.unsqueeze(-2), p=2, dim=-1)

    rgb = jt.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

    sigma = raw[..., 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = jt.init.gauss(sigma.shape, raw.dtype) * raw_noise_std
    alpha = 1. - jt.exp(-jt.nn.relu(sigma + noise) * deltas)  # [N_rays, N_samples]

    # Exclusive cumulative product of (1 - alpha): transmittance before each sample.
    leading_ones = jt.ones((alpha.shape[0], 1))
    transmittance = jt.cumprod(
        jt.concat([leading_ones, 1. - alpha + 1e-10], -1), -1)[:, :-1]
    weights = alpha * transmittance

    rgb_map = jt.sum(weights.unsqueeze(-1) * rgb, -2)  # [N_rays, 3]
    depth_map = jt.sum(weights * z_vals, -1)
    acc_map = jt.sum(weights, -1)
    disp_map = 1. / jt.maximum(1e-10 * jt.ones_like(depth_map),
                               depth_map / acc_map)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map.unsqueeze(-1))

    return rgb_map, disp_map, acc_map, weights, depth_map
def write_imgs(imgss, path, pad=1):
    """Tile a grid of CHW images into a single image and write it to ``path``.

    ``imgss[i][j]`` is the image at row ``i``, column ``j``. Each image is
    followed by a ``pad``-pixel-wide separator of ones, and each row by a
    ``pad``-pixel-tall separator of ones. All images must share the channel
    count and height of ``imgss[0][0]``. Widths may vary between columns as
    long as every row ends up with the same total width.

    Args:
        imgss: list of rows, each a list of (C, H, W) Vars.
        path: destination passed to ``write_img``.
        pad: separator thickness in pixels.
    """
    ch, h, _ = imgss[0][0].shape
    # Use the images' own channel count instead of a hard-coded 3.
    vertical = jt.ones([ch, h, pad])

    strips = []
    for row in imgss:
        pieces = []
        for img in row:
            pieces.append(img)
            pieces.append(vertical)
        row_img = jt.concat(pieces, 2)
        strips.append(row_img)
        # Size the horizontal separator from the row itself. The previous
        # formula read imgss[0][1] and raised IndexError for one-column grids.
        strips.append(jt.ones([ch, pad, row_img.shape[2]]))

    write_img(jt.concat(strips, 1), path)
def random_subsample(pcd, n_points=2048):
    """
    Pick ``n_points`` points from every cloud without replacement, using an
    independent random permutation per batch element.

    NOTE(review): the index shapes only line up when N >= n_points.

    Args:
        pcd: (B, N, 3)
    returns:
        new_pcd: (B, n_points, 3)
    """
    b, n, _ = pcd.shape
    batch_idx = jittor.arange(b).reshape((-1, 1)).repeat(1, n_points)
    per_cloud = [jittor.randperm(n)[:n_points].reshape((1, -1)) for _ in range(b)]
    point_idx = jittor.concat(per_cloud, 0)
    return pcd[batch_idx, point_idx, :]
def laplace_mse(out, gt):
    """Mean squared error between the Laplacian of the model output and the
    target Laplacian.

    Args:
        out: dict holding the model output under ``OUT_KEY`` and the model
            input (differentiation variable) under ``IN_KEY``.
        gt: dict holding the target image under ``IMAGE_KEY`` (channels on
            the last axis) and the target Laplacian under ``LAP_KEY``.

    Returns:
        Scalar mean of the squared difference.

    Raises:
        ValueError: if the target image has neither 1 nor 3 channels.
    """
    channels = gt[IMAGE_KEY].shape[-1]
    if channels == 1:
        out_lap = laplace(out[OUT_KEY], out[IN_KEY])
    elif channels == 3:
        # One Laplacian per colour channel, re-joined on the last axis.
        out_lap = jittor.concat(
            [laplace(out[OUT_KEY][..., i], out[IN_KEY]) for i in range(channels)],
            dim=-1)
    else:
        # Previously a bare ValueError() with no hint about the cause.
        raise ValueError(
            f"laplace_mse supports 1 or 3 channels, got {channels}")
    return ((out_lap - gt[LAP_KEY])**2).mean()
def add_remaining_self_loops(edge_index, edge_weight: Optional[Var] = None,
                             fill_value: float = 1.,
                             num_nodes: Optional[int] = None):
    r"""Ensure every node :math:`i \in \mathcal{V}` of the graph given by
    :attr:`edge_index` has exactly one self-loop :math:`(i,i) \in \mathcal{E}`.

    All existing self-loops are dropped and one loop per node is appended
    after the remaining edges. For a weighted graph, a node that already had
    a self-loop keeps that loop's weight; every other node's loop gets
    :obj:`fill_value`.

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (float, optional): Weight for newly created self-loops
            when :obj:`edge_weight` is not :obj:`None`. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`Var int32`, :class:`Var`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)
    src = edge_index[0]
    is_regular = src != edge_index[1]  # True for edges that are not self-loops

    full_loops = jt.arange(0, N, dtype=src.dtype).unsqueeze(0).repeat(2, 1)
    new_index = jt.concat([edge_index[:, is_regular], full_loops], dim=1)

    if edge_weight is not None:
        is_loop = jt.logical_not(is_regular)
        loop_weight = init.constant((N, ), edge_weight.dtype, fill_value)
        kept_loop_weight = edge_weight[is_loop]
        if kept_loop_weight.numel() > 0:
            # Carry over the weights of the self-loops that already existed.
            loop_weight[src[is_loop]] = kept_loop_weight
        edge_weight = jt.concat([edge_weight[is_regular], loop_weight], dim=0)

    return new_index, edge_weight
def two_grad_mse(out, gt):
    """Squared-error loss between the gradient of the model output with
    respect to its input and the target gradient.

    The squared error is summed over the last axis, then averaged.

    Args:
        out: dict holding the model output under ``OUT_KEY`` and the model
            input (differentiation variable) under ``IN_KEY``.
        gt: dict holding the target image under ``IMAGE_KEY`` (channels on
            the last axis) and the target gradient under ``GRAD_KEY``.

    Returns:
        Scalar loss.

    Raises:
        ValueError: if the target image has neither 1 nor 3 channels.
    """
    channels = gt[IMAGE_KEY].shape[-1]
    if channels == 1:
        grads = jittor.grad(out[OUT_KEY], out[IN_KEY])
    elif channels == 3:
        # One gradient per colour channel, re-joined on the last axis.
        grads = jittor.concat(
            [jittor.grad(out[OUT_KEY][..., i], out[IN_KEY])
             for i in range(channels)],
            dim=-1)
    else:
        # Previously a bare ValueError() with no hint about the cause.
        raise ValueError(
            f"two_grad_mse supports 1 or 3 channels, got {channels}")
    return ((grads - gt[GRAD_KEY])**2).sum(-1).mean()
def contains_isolated_nodes(edge_index, num_nodes=None):
    r"""Return :obj:`True` if some node of the graph given by
    :attr:`edge_index` is not an endpoint of any edge once self-loops have
    been dropped with ``remove_self_loops``.

    Args:
        edge_index (Var int32): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: bool
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    non_loop_edges, _ = remove_self_loops(edge_index)
    row, col = non_loop_edges
    touched = jt.unique(jt.concat((row, col)))
    return touched.size(0) < num_nodes
def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.

    Each slice of at most ``chunk`` rays goes through ``render_rays`` with
    ``kwargs``. The per-key outputs are concatenated along dim 0 in key
    order of first appearance.
    """
    pieces = {}
    for start in range(0, rays_flat.shape[0], chunk):
        chunk_out = render_rays(rays_flat[start:start + chunk], **kwargs)
        for key in chunk_out:
            pieces.setdefault(key, []).append(chunk_out[key])
        if jt.flags.no_grad:
            # Presumably flushes the lazily built graph for this chunk
            # before the next one is queued.
            jt.sync_all()
    return {key: jt.concat(parts, 0) for key, parts in pieces.items()}
def upsample_cat(self, p1, p2, p3, p4):
    """Bilinearly resize four feature maps to ``(self.numAngle, self.numRho)``
    (``align_corners=True``) and concatenate them along dim 1."""
    target_size = (self.numAngle, self.numRho)
    resized = [
        nn.interpolate(fmap, size=target_size, mode='bilinear',
                       align_corners=True)
        for fmap in (p1, p2, p3, p4)
    ]
    return jt.concat(resized, dim=1)
def collate(data_list):
    r"""Collates a python list of data objects to the internal storage format
    of :class:`torch_geometric.data.InMemoryDataset`.

    Returns a ``(data, slices)`` pair:

    * ``data`` holds, per key, the values of all items joined together.
      Non-scalar Vars are concatenated along ``__cat_dim__`` (default 0) and
      scalar Vars are stacked. With a single item, the Var is kept as-is.
      Python ints/floats become a Var.
    * ``slices[key]`` is an int32 Var of cumulative offsets. A non-scalar Var
      value adds its size along the concat dim; any other value adds 1.
    """
    first = data_list[0]
    keys = first.keys
    data = first.__class__()
    for key in keys:
        data[key] = []
    slices = {key: [0] for key in keys}

    # Gather every item's values (item-major) and record running offsets.
    for item, key in product(data_list, keys):
        value = item[key]
        data[key].append(value)
        step = 1
        if isinstance(value, Var) and value.ndim > 0:
            dim = item.__cat_dim__(key, value)
            step = value.size(0 if dim is None else dim)
        slices[key].append(slices[key][-1] + step)

    if hasattr(first, '__num_nodes__'):
        data.__num_nodes__ = [item.num_nodes for item in data_list]

    several = len(data_list) > 1
    for key in keys:
        sample = first[key]
        if isinstance(sample, Var):
            if not several:
                # Don't duplicate attributes...
                data[key] = data[key][0]
            elif sample.ndim > 0:
                dim = data.__cat_dim__(key, sample)
                data[key] = jt.concat(data[key], dim=0 if dim is None else dim)
            else:
                data[key] = jt.stack(data[key])
        elif isinstance(sample, (int, float)):
            data[key] = jt.array(data[key])
        slices[key] = jt.array(slices[key], dtype=Var.int32)

    return data, slices
def execute(self, x):
    """Append minibatch standard-deviation feature maps to ``x``.

    With ``G = min(self.group_size, b)`` and ``F = self.num_new_features``,
    ``x`` is viewed as ``[G, b // G, F, c // F, h, w]``. The standard
    deviation over the G axis (with 1e-8 added before the square root) is
    averaged over each channel subset and all spatial positions. The result
    is broadcast to ``(b, F, h, w)`` and concatenated after ``x`` on dim 1,
    giving ``(b, c + F, h, w)``.
    """
    b, c, h, w = x.shape
    group_size = min(self.group_size, b)
    n_feat = self.num_new_features

    grouped = x.reshape([group_size, -1, n_feat, c // n_feat, h, w])
    centered = grouped - grouped.mean(0, keepdims=True)
    std = (centered.sqr().mean(0, keepdims=True) + 1e-8).pow(0.5)

    # Average over subset channels and space, then drop the channel axis.
    stat = std.mean([3, 4, 5], keepdims=True).squeeze(3)
    stat = stat.expand([group_size, stat.size(1), stat.size(2), h, w]).clone()
    stat = stat.reshape(b, n_feat, h, w)

    return jt.concat([x, stat], dim=1)
def pack(self, batch_data):
    """Collate ``batch_data`` into ``{'img': Var, 'label': list}``.

    Each sample's ``sample['img'][0]`` (converted with ``.numpy()``) is
    resized with ``resize_with_specific_height`` and padded with
    ``width_pad_img`` to the batch's widest image, rounded up to a multiple
    of 8. It is then normalised, transposed with ``[2, 0, 1]`` and turned
    into a float Var. Labels are ``sample['label'][0]``.

    NOTE(review): concatenating the per-sample (C, H, W) arrays on dim 0
    yields (B * C, H, W), not (B, C, H, W); confirm the consumer expects this.
    """
    # img tensor current shape: B,H,W,C
    resized = [
        self.process.resize_with_specific_height(sample['img'][0].numpy())
        for sample in batch_data
    ]
    widest = max({im.shape[1] for im in resized})
    # make sure the common width is an integral multiple of 8
    target_w = int(np.ceil(widest / 8) * 8)

    images, labels = [], []
    for sample, im in zip(batch_data, resized):
        label = sample['label'][0]
        padded = self.process.width_pad_img(im, target_w)
        chw = self.process.normalize_img(padded).transpose([2, 0, 1])
        images.append(jittor.array(chw, dtype=jittor.float))
        labels.append(label)

    return {'img': jittor.concat(images, dim=0), 'label': labels}
def __call__(self, batch):
    """Collate ``batch`` into ``{'img': Var, 'label': list}``.

    Each ``sample['img']`` is resized with ``resize_with_specific_height``
    and padded with ``width_pad_img`` to the batch's widest image, rounded up
    to a multiple of 8. It is then normalised, transposed with ``[2, 0, 1]``
    and turned into a float Var. Labels are ``sample['label']``, in batch
    order.

    NOTE(review): the per-sample (C, H, W) arrays are concatenated on dim 1,
    giving (C, B * H, W); confirm this layout is intended (the sibling
    ``pack`` concatenates on dim 0).
    """
    resized = [
        self.process.resize_with_specific_height(sample['img'])
        for sample in batch
    ]
    widest = max({im.shape[1] for im in resized})
    # make sure the common width is an integral multiple of 8
    target_w = int(np.ceil(widest / 8) * 8)

    labels, tensors = [], []
    for sample, im in zip(batch, resized):
        labels.append(sample['label'])
        padded = self.process.width_pad_img(im, target_w)
        chw = self.process.normalize_img(padded).transpose([2, 0, 1])
        tensors.append(jittor.array(chw, dtype=jittor.float))

    return {'img': jittor.concat(tensors, dim=1), 'label': labels}
def draw_mask(img, ww):
    """Visualise the weight window stored at the centre cell of ``ww``.

    ``img`` is (ch, h, w) and ``ww`` is indexed ``[fm_h, fm_w, h_, w_]``. The
    ``(h_, w_)`` window at ``ww[fm_h // 2, fm_w // 2]`` is painted onto a
    zero canvas of ``img``'s shape. Each entry fills one
    ``(h // fm_h, w // fm_w)`` tile of the feature-map grid, with the window
    centred on the centre cell; tiles falling outside the grid are skipped.

    Returns:
        ``concat([canvas, img / 2 + canvas * 2], dim=2)``: the canvas and the
        overlay side by side.
    """
    ch, h, w = img.shape
    fm_h, fm_w, win_h, win_w = ww.shape
    cy, cx = fm_h // 2, fm_w // 2
    tile_h, tile_w = h // fm_h, w // fm_w

    mask = jt.zeros([ch, h, w])
    for di in range(win_h):
        i = cy + di - win_h // 2
        if not 0 <= i < fm_h:
            continue
        for dj in range(win_w):
            j = cx + dj - win_w // 2
            if not 0 <= j < fm_w:
                continue
            mask[:, i * tile_h:(i + 1) * tile_h,
                 j * tile_w:(j + 1) * tile_w] = ww[cy, cx, di, dj]

    overlay = img / 2 + mask * 2
    return jt.concat([mask, overlay], 2)
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64):
    """Embed the inputs and evaluate network ``fn`` on them in chunks.

    The last axis of ``inputs`` is the feature axis; all leading axes are
    flattened for the network call and restored on the output. When
    ``viewdirs`` is given, it is broadcast over the sample axis to
    ``inputs.shape``, embedded with ``embeddirs_fn``, and appended to the
    position embedding. ``fn`` is applied in chunks of ``netchunk`` rows via
    ``batchify``.
    """
    flat_pts = jt.reshape(inputs, [-1, inputs.shape[-1]])
    features = embed_fn(flat_pts)

    if viewdirs is not None:
        dirs = viewdirs[:, None].expand(inputs.shape)
        flat_dirs = jt.reshape(dirs, [-1, dirs.shape[-1]])
        features = jt.concat([features, embeddirs_fn(flat_dirs)], -1)

    flat_out = batchify(fn, netchunk)(features)
    out_shape = list(inputs.shape[:-1]) + [flat_out.shape[-1]]
    return jt.reshape(flat_out, out_shape)
def clip_grad_norm(self, max_norm:float, norm_type:int=2):
    r"""Clips gradient norm of this optimizer.

    The norm is computed over the gradients of all parameters that are not
    stop-grad, flattened and concatenated together. Every such gradient is
    then scaled by ``min(max_norm / (total_norm + 1e-6), 1.0)``. Nothing
    happens when ``self.__zero_grad`` is set or when there are no trainable
    parameters.

    Args:
        max_norm (float or int): max norm of the gradients
        norm_type (int): 1-norm or 2-norm

    Example::

        a = jt.ones(2)
        opt = jt.optim.SGD([a], 0.1)

        loss = a*a
        opt.zero_grad()
        opt.backward(loss)

        print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
        opt.clip_grad_norm(0.01, 2)
        print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01

        opt.step()
    """
    if self.__zero_grad:
        return

    def _trainable_grads():
        # Yields the gradient of every parameter that is not stop-grad.
        for group in self.param_groups:
            for param, grad in zip(group["params"], group["grads"]):
                if not param.is_stop_grad():
                    yield grad

    flat = [grad.flatten() for grad in _trainable_grads()]
    if not flat:
        return

    total_norm = jt.norm(jt.concat(flat), norm_type)
    clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
    for grad in _trainable_grads():
        grad *= clip_coef
def style_mixing(generator, step, mean_style, n_source, n_target):
    """Build a style-mixing image grid, returned as one batch.

    Order of the concatenated batch (dim 0):

    1. a ``(1, 3, S, S)`` placeholder filled with -1, where ``S = 4 * 2**step``;
    2. the ``n_source`` source images;
    3. for each of the ``n_target`` targets: the target image, followed by
       ``n_source`` images generated from ``[target code repeated n_source
       times, source codes]`` with ``mixing_range=(0, 1)``.

    All generator calls use ``alpha=1`` and ``style_weight=0.7``; codes are
    512-dimensional ``jt.randn`` draws.
    """
    source_code = jt.randn(n_source, 512)
    target_code = jt.randn(n_target, 512)
    size = 4 * 2**step
    gen_kwargs = dict(step=step, alpha=1, mean_style=mean_style,
                      style_weight=0.7)

    placeholder = jt.ones((1, 3, size, size)) * -1
    source_image = generator(source_code, **gen_kwargs)
    target_image = generator(target_code, **gen_kwargs)

    images = [placeholder, source_image]
    for t in range(n_target):
        mixed = generator(
            [target_code[t].unsqueeze(0).repeat(n_source, 1), source_code],
            mixing_range=(0, 1),
            **gen_kwargs,
        )
        images.extend([target_image[t].unsqueeze(0), mixed])

    return jt.concat(images, 0)
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. Columns 0:3 are ray
        origins, 3:6 ray directions, 6:8 near/far bounds; when the row has
        more than 8 columns, the last 3 are the unit-magnitude viewing
        direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time. When 0, importance sampling is deterministic.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn. If None,
        network_fn is reused for the fine pass.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: std of Gaussian noise added to the predicted density
        (forwarded to integrator).
      verbose: bool. Currently unused by this function.
    Returns:
      dict with:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from the
        fine pass when N_importance > 0.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from
        the fine pass when N_importance > 0.
      raw (only if retraw): [num_rays, num_samples, 4]. Raw predictions of the
        last pass (N_samples + N_importance samples when N_importance > 0).
      rgb0, disp0, acc0 (only if N_importance > 0): rgb_map / disp_map /
        acc_map of the coarse pass.
      NOTE(review): no z_std entry is produced, although the original
        docstring listed one.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = jt.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1,1]

    # Coarse pass: stratified depths along each ray.
    z_vals = sample(N_rays, N_samples, lindisp, perturb, near, far)
    pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
        -1)  # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = integrator(
        raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    # Keep the coarse results; they are returned as rgb0/disp0/acc0 below.
    rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

    #usefulRayIndex = jt.nonzero(acc_map > 0.1)
    if N_importance > 0:
        # importance sampling
        # Sample extra depths from the coarse weights (interior bins only).
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],
                               N_importance,
                               det=(perturb == 0.))
        z_samples = z_samples.detach()

        # Presumably jt.argsort returns (indices, sorted values); keep the values.
        _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1)
        pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
            -1)  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = integrator(
            raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0

    return ret
def write_img(img, path):
    """Write a (3, H, W) RGB Var to ``path`` with OpenCV.

    Values are multiplied by 255 (presumably the input is in [0, 1]), and the
    channels are reordered to BGR for ``cv2.imwrite``.
    """
    hwc = img.permute([1, 2, 0]) * 255
    # OpenCV expects BGR channel order: pick channels 2, 1, 0.
    bgr = jt.concat([hwc[:, :, c:c + 1] for c in (2, 1, 0)], 2)
    cv2.imwrite(path, bgr.data)
def render(H,
           W,
           focal,
           chunk=1024 * 32,
           rays=None,
           c2w=None,
           intrinsic=None,
           ndc=True,
           near=0.,
           far=1.,
           use_viewdirs=False,
           c2w_staticcam=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch. Only used when c2w is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      intrinsic: optional camera intrinsics forwarded to pinhole_get_rays
        together with c2w. Must be None when c2w_staticcam is given.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
        NOTE(review): near/far are multiplied by a [num_rays, 1] ones Var, so
        an array must broadcast against [num_rays, 1]; a flat [batch_size]
        array would not — confirm.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
      **kwargs: forwarded to batchify_rays (and from there to render_rays).
    Returns:
      A list [rgb_map, disp_map, acc_map, extras]:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything else returned by render_rays().
      Every output is reshaped to the leading shape of rays_d (before NDC).
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w, intrinsic)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # Origins/directions come from the static camera; viewdirs above
            # still come from c2w.
            assert intrinsic is None
            rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / jt.norm(viewdirs, p=2, dim=-1, keepdim=True)
        viewdirs = jt.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = jt.reshape(rays_o, [-1, 3]).float()
    rays_d = jt.reshape(rays_d, [-1, 3]).float()

    near, far = near * jt.ones_like(rays_d[..., :1]), far * jt.ones_like(
        rays_d[..., :1])
    # Row layout: origin(3) | direction(3) | near(1) | far(1) [| viewdir(3)]
    rays = jt.concat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = jt.concat([rays, viewdirs], -1)

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = jt.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
def get_laplacian(edge_index, edge_weight: Optional[Var] = None,
                  normalization: Optional[str] = None,
                  dtype: Optional[int] = None,
                  num_nodes: Optional[int] = None):
    r""" Computes the graph Laplacian of the graph given by :obj:`edge_index`
    and optional :obj:`edge_weight`.

    Self-loops in the input are removed first; degrees are the sums of edge
    weights grouped by source node (``edge_index[0]``).

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
        dtype (Var.dtype, optional): The desired data type of returned Var
            in case :obj:`edge_weight=None`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    Returns:
        ``(edge_index, edge_weight)`` of the Laplacian, with one diagonal
        entry appended per node.
    """
    if normalization is not None:
        assert normalization in ['sym', 'rw']  # 'Invalid normalization'

    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

    if edge_weight is None:
        edge_weight = jt.ones((edge_index.size(1)), dtype=dtype)

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    row, col = edge_index[0], edge_index[1]
    shape = list(edge_weight.shape)
    shape[0] = num_nodes
    deg = jt.zeros(shape)
    deg = jt.scatter(deg, 0, row, src=edge_weight, reduce='add')

    if normalization is None:
        # L = D - A.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        edge_weight = jt.concat([-edge_weight, deg], dim=0)
    elif normalization == 'sym':
        # Compute A_norm = -D^{-1/2} A D^{-1/2}.
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree nodes (e.g. nodes with no outgoing edge that still
        # appear in `col`) give inf. Zero them in one vectorised op instead
        # of a per-node Python loop. masked_fill returns a new Var, so the
        # result must be reassigned.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # L = I - A_norm.
        edge_index, tmp = add_self_loops(edge_index, -edge_weight,
                                         fill_value=1., num_nodes=num_nodes)
        assert tmp is not None
        edge_weight = tmp
    else:
        # Compute A_norm = -D^{-1} A.
        deg_inv = 1.0 / deg
        # Previously the masked_fill result was discarded, so inf survived.
        deg_inv = deg_inv.masked_fill(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight

        # L = I - A_norm.
        edge_index, tmp = add_self_loops(edge_index, -edge_weight,
                                         fill_value=1., num_nodes=num_nodes)
        assert tmp is not None
        edge_weight = tmp

    return edge_index, edge_weight
def read_planetoid_data(folder, prefix):
    """Load a Planetoid-format dataset (e.g. Cora/CiteSeer/PubMed/NELL).

    Reads the eight raw pieces ``x, tx, allx, y, ty, ally, graph,
    test.index`` via ``read_file`` and assembles a ``Data`` object.

    * Train nodes are the first ``y.size(0)`` nodes.
    * Validation nodes are the next 500.
    * Test nodes are those listed in ``test.index``.
    * Test rows are reordered so that row ``test_index[k]`` holds the
      features/labels from sorted position ``k``.
    * CiteSeer and NELL get dataset-specific padding for missing test ids.
    * For NELL, ``x`` becomes a SparseTensor with extra one-hot columns for
      isolated nodes.

    Args:
        folder: directory passed to ``read_file``.
        prefix: dataset name passed to ``read_file`` (case-insensitive
            checks for 'citeseer' and 'nell.0.001').

    Returns:
        Data with x, edge_index, y and train/val/test masks.
    """
    names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
    items = [read_file(folder, prefix, name) for name in names]
    x, tx, allx, y, ty, ally, graph, test_index = items
    train_index = jt.arange(y.size(0), dtype=Var.int32)
    val_index = jt.arange(y.size(0), y.size(0) + 500, dtype=Var.int32)
    # NOTE(review): presumably jt argsort returns (indices, sorted values);
    # [1] then selects the sorted test ids.
    sorted_test_index = test_index.argsort()[1]

    if prefix.lower() == 'citeseer':
        # There are some isolated nodes in the Citeseer graph, resulting in
        # none consecutive test indices. We need to identify them and add them
        # as zero vectors to `tx` and `ty`.
        len_test_indices = (test_index.max() - test_index.min()).item() + 1

        tx_ext = jt.zeros((len_test_indices, tx.size(1)))
        tx_ext[sorted_test_index - test_index.min(), :] = tx
        ty_ext = jt.zeros((len_test_indices, ty.size(1)))
        ty_ext[sorted_test_index - test_index.min(), :] = ty

        tx, ty = tx_ext, ty_ext

    if prefix.lower() == 'nell.0.001':
        # Zero-pad tx/ty so that every node after allx has a row.
        tx_ext = jt.zeros((len(graph) - allx.size(0), x.size(1)))
        tx_ext[sorted_test_index - allx.size(0)] = tx

        ty_ext = jt.zeros((len(graph) - ally.size(0), y.size(1)))
        ty_ext[sorted_test_index - ally.size(0)] = ty

        tx, ty = tx_ext, ty_ext

        x = jt.concat([allx, tx], dim=0)
        x[test_index] = x[sorted_test_index]

        # Creating feature vectors for relations.
        row, col, value = SparseTensor.from_dense(x).coo()
        rows, cols, values = [row], [col], [value]

        # Nodes past allx that are not test nodes get their own one-hot
        # feature column appended after the existing x.size(1) columns.
        mask1 = index_to_mask(test_index, size=len(graph))
        mask2 = index_to_mask(jt.arange(allx.size(0), len(graph)),
                              size=len(graph))
        mask = jt.logical_or(jt.logical_not(mask1), jt.logical_not(mask2))
        isolated_index = mask.nonzero(as_tuple=False).view(-1)[allx.size(0):]

        rows += [isolated_index]
        cols += [jt.arange(isolated_index.size(0)) + x.size(1)]
        values += [jt.ones((isolated_index.size(0)))]

        x = SparseTensor(row=jt.concat(rows), col=jt.concat(cols),
                         value=jt.concat(values))
    else:
        x = jt.concat([allx, tx], dim=0)
        x[test_index] = x[sorted_test_index]

    # NOTE(review): presumably jt argmax returns (indices, values); [0]
    # selects the class index of each one-hot label row.
    y = jt.concat([ally, ty], dim=0).argmax(dim=1)[0]
    y[test_index] = y[sorted_test_index]

    train_mask = index_to_mask(train_index, size=y.size(0))
    val_mask = index_to_mask(val_index, size=y.size(0))
    test_mask = index_to_mask(test_index, size=y.size(0))

    edge_index = edge_index_from_dict(graph, num_nodes=y.size(0))

    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data