Example #1
def add_self_loops(edge_index,
                   edge_weight: Optional[Var] = None,
                   fill_value: float = 1.,
                   num_nodes: Optional[int] = None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted, self-loops will be added with edge weights
    denoted by :obj:`fill_value`.

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`,
            will add self-loops with edge weights of :obj:`fill_value` to the
            graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`Var int32`, :class:`Var`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = jt.arange(0, N, dtype=Var.int32)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        loop_weight = init.constant((N, ), edge_weight.dtype, fill_value)
        edge_weight = jt.concat([edge_weight, loop_weight], dim=0)

    edge_index = jt.concat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight
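
A minimal usage sketch (a sketch only, assuming jittor is imported as jt and that add_self_loops and its helper maybe_num_nodes are importable from the surrounding graph library):

# Toy graph: 4 directed edges over 3 nodes.
edge_index = jt.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
edge_weight = jt.array([1., 2., 3., 4.])
edge_index, edge_weight = add_self_loops(edge_index, edge_weight)
# edge_index gains the columns (0, 0), (1, 1) and (2, 2);
# edge_weight gains three entries equal to fill_value, i.e. 1.
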
Example #2
    def execute(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = jt.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = jt.concat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = jt.concat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = jt.concat((out, self.pool(spx[self.nums])), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
Example #3
    def execute(self, x):
        input_pts, input_views = jt.split(x,
                                          [self.input_ch, self.input_ch_views],
                                          dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = jt.nn.relu(h)
            if i in self.skips:
                h = jt.concat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = jt.concat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = jt.nn.relu(h)

            rgb = self.rgb_linear(h)
            outputs = jt.concat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs
Example #4
    def execute(self, pcd, prev_s):
        """
        Args:
            pcd: b, n, 3
            prev_s: b, c, n
        """
        b, n, _ = pcd.shape
        pcd_bcn = pcd.transpose(0, 2, 1)
        l0_xyz = pcd
        l0_points = pcd_bcn
        if self.if_noise:
            noise_points = init.gauss([b, 3, n], 'float', mean=0.0, std=self.noise_stdv)
            l0_points = jittor.concat([l0_points, noise_points], 1)
        l1_xyz, l1_points = self.sa_module_1(l0_xyz, l0_points)  # b, 512, 128 (bnc)
        l2_xyz, l2_points = self.sa_module_2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa_module_3(l2_xyz, l2_points)

        l2_points = self.fp_module_3(l2_xyz, l3_xyz, l2_points, l3_points)
        l2_points, prev_s['l2'] = self.unit_3(l2_points, prev_s['l2'])

        l1_points = self.fp_module_2(l1_xyz, l2_xyz, l1_points, l2_points)
        l1_points, prev_s['l1'] = self.unit_2(l1_points, prev_s['l1'])

        l0_points = self.fp_module_1(l0_xyz, l1_xyz, concat([pcd_bcn, pcd_bcn], dim=1), l1_points)
        l0_points, prev_s['l0'] = self.unit_1(l0_points, prev_s['l0'])  # (B, 128, 2048)

        noise = init.gauss([b, 32, n], 'float', mean=0.0, std=1.0)
        feat = concat([l0_points, noise], dim=1)
        delta_xyz = self.tanh(self.mlp_conv(feat)) * 1.0 / 10 ** (self.step - 1)
        point_cloud = (pcd_bcn + delta_xyz).transpose(0, 2, 1)
        return point_cloud, delta_xyz
Example #5
def sample_pdf(bins, weights, N_samples, det=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / jt.sum(weights, -1, keepdims=True)
    cdf = jt.cumsum(pdf, -1)
    cdf = jt.concat([jt.zeros_like(cdf[..., :1]), cdf],
                    -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = jt.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = jt.random(list(cdf.shape[:-1]) + [N_samples])

    # Invert CDF
    inds = jt.searchsorted(cdf, u, right=True)
    below = jt.maximum(jt.zeros_like(inds - 1), inds - 1)
    above = jt.minimum((cdf.shape[-1] - 1) * jt.ones_like(inds), inds)
    inds_g = jt.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = jt.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = jt.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom[denom < 1e-5] = 1.0
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
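
A minimal sketch of calling sample_pdf (assuming jt is jittor): deterministically draw 4 samples per ray by inverting the CDF of a piecewise-constant pdf defined over 5 bin edges.

bins = jt.linspace(0., 1., steps=5).unsqueeze(0)   # (1, 5) bin edges
weights = jt.array([[0.1, 0.2, 0.4, 0.3]])         # (1, 4) unnormalized weights
samples = sample_pdf(bins, weights, N_samples=4, det=True)
# samples: (1, 4); more samples land in the higher-weight bins
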
Example #6
    def execute(self, x):
        out = self.conv1(nn.relu(self.bn1(x)))
        out = self.conv2(nn.relu(self.bn2(out)))
        out = jt.transpose(out, (1, 0, 2, 3))
        x = jt.transpose(x, (1, 0, 2, 3))
        out = jt.concat([out, x], 0)
        out = jt.transpose(out, (1, 0, 2, 3))
        #out = jt.reshape(out, [x.shape[0],-1,out.shape[2],out.shape[3]])
        return out
Example #7
def sample(N_rays, N_samples, lindisp, perturb, near, far):
    t_vals = jt.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1. - t_vals) + far * (t_vals)
    else:
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = jt.concat([mids, z_vals[..., -1:]], -1)
        lower = jt.concat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = jt.random(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand

    return z_vals
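
A minimal sketch (assuming jt is jittor): 2 rays with 8 depth samples each, spaced uniformly between near=2 and far=6 with no stratified jitter.

near = 2. * jt.ones((2, 1))
far = 6. * jt.ones((2, 1))
z_vals = sample(N_rays=2, N_samples=8, lindisp=False, perturb=0., near=near, far=far)
# z_vals: (2, 8), increasing from 2 to 6 along each ray
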
Example #8
def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
    r"""Removes the isolated nodes from the graph given by :attr:`edge_index`
    with optional edge attributes :attr:`edge_attr`.
    In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter
    out isolated node features later on.
    Self-loops are preserved for non-isolated nodes.

    Args:
        edge_index (Var int32): The edge indices.
        edge_attr (Var, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (Var int32, Var, Var bool)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    out = segregate_self_loops(edge_index, edge_attr)
    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out

    mask = jt.zeros((num_nodes), dtype=Var.bool)
    mask[edge_index.view(-1)] = 1

    assoc = jt.full((num_nodes, ), -1, dtype=Var.int32)
    assoc[mask] = jt.arange(mask.sum())
    edge_index = assoc[edge_index]

    loop_mask = jt.zeros_like(mask)
    loop_mask[loop_edge_index[0]] = 1
    loop_mask = loop_mask & mask
    loop_assoc = jt.full_like(assoc, -1)
    loop_assoc[loop_edge_index[0]] = jt.arange(loop_edge_index.size(1))
    loop_idx = loop_assoc[loop_mask]
    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]

    edge_index = jt.concat([edge_index, loop_edge_index], dim=1)

    if edge_attr is not None:
        loop_edge_attr = loop_edge_attr[loop_idx]
        edge_attr = jt.concat([edge_attr, loop_edge_attr], dim=0)

    return edge_index, edge_attr, mask
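
A minimal sketch (assuming jt is jittor and the helpers used above, e.g. segregate_self_loops and maybe_num_nodes, are importable): node 2 touches no edge, so it is masked out and the remaining nodes are re-indexed.

edge_index = jt.array([[0, 1], [1, 0]])
edge_index, edge_attr, mask = remove_isolated_nodes(edge_index, num_nodes=3)
# mask == [True, True, False]; edge_attr stays None because none was passed
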
Example #9
def integrator(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    raw2alpha = lambda raw, dists, act_fn=jt.nn.relu: 1. - jt.exp(-act_fn(raw)
                                                                  * dists)

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = jt.concat([
        dists,
        jt.array(np.array([1e10]).astype(np.float32)).expand(
            dists[..., :1].shape)
    ], -1)  # [N_rays, N_samples]
    dists = dists * jt.norm(rays_d.unsqueeze(-2), p=2, dim=-1)

    rgb = jt.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = jt.init.gauss(raw[..., 3].shape, raw.dtype) * raw_noise_std
    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
    weights = alpha * jt.cumprod(
        jt.concat([jt.ones(
            (alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = jt.sum(weights.unsqueeze(-1) * rgb, -2)  # [N_rays, 3]

    depth_map = jt.sum(weights * z_vals, -1)
    disp_map = 1. / jt.maximum(1e-10 * jt.ones_like(depth_map),
                               depth_map / jt.sum(weights, -1))
    acc_map = jt.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map.unsqueeze(-1))

    return rgb_map, disp_map, acc_map, weights, depth_map
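
A minimal sketch (assuming jt is jittor and numpy is imported as np, as integrator itself requires): run the integrator on random network outputs for 2 rays with 8 samples each.

raw = jt.random([2, 8, 4])      # per-sample (r, g, b, sigma) predictions
z_vals = jt.linspace(2., 6., steps=8).unsqueeze(0).repeat(2, 1)
rays_d = jt.ones((2, 3))
rgb_map, disp_map, acc_map, weights, depth_map = integrator(raw, z_vals, rays_d)
# rgb_map: (2, 3); weights: (2, 8); disp_map, acc_map, depth_map: (2,)
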
Example #10
def write_imgs(imgss, path, pad=1):
    hn = len(imgss)
    wn = len(imgss[0])
    ch, h, w = imgss[0][0].shape
    h_ = (h + pad) * hn
    w_ = (w + pad) * wn
    out = []
    vertical = jt.ones([3, h, pad])
    horizontal = jt.ones([3, pad, (w + pad) + (wn - 1) * (imgss[0][1].shape[2] + pad)])
    # horizontal[0] = 1
    for i in range(hn):
        out_ = []
        for j in range(wn):
            out_.append(imgss[i][j])
            out_.append(vertical)
            # out[:, i * (h + pad) : i * (h + pad) + h, j * (w + pad) : j * (w + pad) + w] = imgss[i][j]
        out.append(jt.concat(out_, 2))
        out.append(horizontal)
    out = jt.concat(out, 1)
    write_img(out, path)
Example #11
def random_subsample(pcd, n_points=2048):
    """
    Args:
        pcd: (B, N, 3)

    returns:
        new_pcd: (B, n_points, 3)
    """
    b, n, _ = pcd.shape
    batch_idx = jittor.arange(b).reshape((-1, 1)).repeat(1, n_points)
    idx = jittor.concat([jittor.randperm(n)[:n_points].reshape((1, -1)) for i in range(b)], 0)
    return pcd[batch_idx, idx, :]
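
A minimal sketch (assuming jittor is imported as in the snippet above):

pcd = jittor.random([4, 4096, 3])            # batch of 4 dense point clouds
pcd_small = random_subsample(pcd, n_points=2048)
# pcd_small: (4, 2048, 3), points drawn without replacement per cloud
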
Example #12
def laplace_mse(out, gt):
    channels = gt[IMAGE_KEY].shape[-1]

    if channels == 1:
        out_lap = laplace(out[OUT_KEY], out[IN_KEY])
    elif channels == 3:
        out_lap = []
        for i in range(channels):
            out_lap.append(laplace(out[OUT_KEY][..., i], out[IN_KEY]))
        out_lap = jittor.concat(out_lap, dim=-1)
    else:
        raise ValueError()
    return ((out_lap - gt[LAP_KEY])**2).mean()
Example #13
def add_remaining_self_loops(edge_index,
                             edge_weight: Optional[Var] = None,
                             fill_value: float = 1.,
                             num_nodes: Optional[int] = None):
    r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted and already contains a few self-loops, only
    non-existent self-loops will be added with edge weights denoted by
    :obj:`fill_value`.

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`,
            will add self-loops with edge weights of :obj:`fill_value` to the
            graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`Var int32`, :class:`Var`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index[0], edge_index[1]
    mask = row != col

    loop_index = jt.arange(0, N, dtype=row.dtype)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = jt.concat([edge_index[:, mask], loop_index], dim=1)

    if edge_weight is not None:
        inv_mask = jt.logical_not(mask)
        loop_weight = init.constant((N, ), edge_weight.dtype, fill_value)
        remaining_edge_weight = edge_weight[inv_mask]
        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight
        edge_weight = jt.concat([edge_weight[mask], loop_weight], dim=0)

    return edge_index, edge_weight
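
A minimal sketch (assuming jt is jittor and the helpers above are importable): node 0 already has a self-loop, so only the missing loops (1, 1) and (2, 2) receive fill_value.

edge_index = jt.array([[0, 0, 1],
                       [0, 1, 2]])
edge_weight = jt.array([0.5, 2.0, 3.0])
edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, num_nodes=3)
# the existing (0, 0) loop keeps its weight 0.5; the new loops get weight 1.
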
Example #14
def two_grad_mse(out, gt):
    channels = gt[IMAGE_KEY].shape[-1]

    if channels == 1:
        grads = jittor.grad(out[OUT_KEY], out[IN_KEY])
    elif channels == 3:
        grads = []
        for i in range(channels):
            grads.append(jittor.grad(out[OUT_KEY][..., i], out[IN_KEY]))
        grads = jittor.concat(grads, dim=-1)
    else:
        raise ValueError()
    return ((grads - gt[GRAD_KEY])**2).sum(-1).mean()
Example #15
def contains_isolated_nodes(edge_index, num_nodes=None):
    r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains
    isolated nodes.

    Args:
        edge_index (Var int32): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: bool
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    (row, col), _ = remove_self_loops(edge_index)

    return jt.unique(jt.concat((row, col))).size(0) < num_nodes
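
A minimal sketch (assuming jt is jittor and remove_self_loops / maybe_num_nodes are importable):

edge_index = jt.array([[0, 1], [1, 0]])
print(contains_isolated_nodes(edge_index, num_nodes=3))  # True: node 2 has no edge
print(contains_isolated_nodes(edge_index, num_nodes=2))  # False
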
Example #16
def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i + chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])
        if jt.flags.no_grad:
            jt.sync_all()

    all_ret = {k: jt.concat(all_ret[k], 0) for k in all_ret}
    return all_ret
Example #17
    def upsample_cat(self, p1, p2, p3, p4):
        p1 = nn.interpolate(p1,
                            size=(self.numAngle, self.numRho),
                            mode='bilinear',
                            align_corners=True)
        p2 = nn.interpolate(p2,
                            size=(self.numAngle, self.numRho),
                            mode='bilinear',
                            align_corners=True)
        p3 = nn.interpolate(p3,
                            size=(self.numAngle, self.numRho),
                            mode='bilinear',
                            align_corners=True)
        p4 = nn.interpolate(p4,
                            size=(self.numAngle, self.numRho),
                            mode='bilinear',
                            align_corners=True)
        return jt.concat([p1, p2, p3, p4], dim=1)
Example #18
    def collate(data_list):
        r"""Collates a python list of data objects to the internal storage
        format of :class:`torch_geometric.data.InMemoryDataset`."""
        keys = data_list[0].keys
        data = data_list[0].__class__()

        for key in keys:
            data[key] = []
        slices = {key: [0] for key in keys}

        for item, key in product(data_list, keys):
            data[key].append(item[key])
            if isinstance(item[key], Var) and item[key].ndim > 0:
                cat_dim = item.__cat_dim__(key, item[key])
                cat_dim = 0 if cat_dim is None else cat_dim
                s = slices[key][-1] + item[key].size(cat_dim)
            else:
                s = slices[key][-1] + 1
            slices[key].append(s)

        if hasattr(data_list[0], '__num_nodes__'):
            data.__num_nodes__ = []
            for item in data_list:
                data.__num_nodes__.append(item.num_nodes)

        for key in keys:
            item = data_list[0][key]
            if isinstance(item, Var) and len(data_list) > 1:
                if item.ndim > 0:
                    cat_dim = data.__cat_dim__(key, item)
                    cat_dim = 0 if cat_dim is None else cat_dim
                    data[key] = jt.concat(data[key], dim=cat_dim)
                else:
                    data[key] = jt.stack(data[key])
            elif isinstance(item, Var):  # Don't duplicate attributes...
                data[key] = data[key][0]
            elif isinstance(item, int) or isinstance(item, float):
                data[key] = jt.array(data[key])

            slices[key] = jt.array(slices[key], dtype=Var.int32)

        return data, slices
Example #19
    def execute(self, x):
        b, c, h, w = x.shape
        group_size = min(self.group_size, b)
        y = x.reshape([
            group_size, -1, self.num_new_features, c // self.num_new_features,
            h, w
        ])
        y = y - y.mean(0, keepdims=True)
        # y = (y ** 2).mean(0, keepdims=True)
        y = (y.sqr()).mean(0, keepdims=True)
        # y = (y + 1e-8) ** 0.5
        y = (y + 1e-8).pow(0.5)
        y = y.mean([3, 4, 5], keepdims=True).squeeze(
            3)  # don't keep the meaned-out channels
        y = y.expand([group_size, y.size(1),
                      y.size(2), h,
                      w]).clone().reshape(b, self.num_new_features, h, w)
        # z = torch.cat([x, y], dim=1)
        z = jt.concat([x, y], dim=1)
        return z
Example #20
    def pack(self, batch_data):
        batch = {'img': [], 'label': []}
        # img tensor current shape: B,H,W,C
        all_same_height_images = [
            self.process.resize_with_specific_height(_['img'][0].numpy())
            for _ in batch_data
        ]
        max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
        # make sure max_img_w is integral multiple of 8
        max_img_w = int(np.ceil(max_img_w / 8) * 8)
        for i in range(len(batch_data)):
            _label = batch_data[i]['label'][0]
            img = self.process.normalize_img(
                self.process.width_pad_img(all_same_height_images[i],
                                           max_img_w))
            img = img.transpose([2, 0, 1])
            batch['img'].append(jittor.array(img, dtype=jittor.float))
            batch['label'].append(_label)
        batch['img'] = jittor.concat(batch['img'], dim=0)
        return batch
Example #21
    def __call__(self, batch):
        resize_images = []

        all_same_height_images = [
            self.process.resize_with_specific_height(_['img']) for _ in batch
        ]
        max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
        # make sure max_img_w is integral multiple of 8
        max_img_w = int(np.ceil(max_img_w / 8) * 8)
        labels = []
        for i in range(len(batch)):
            _label = batch[i]['label']
            labels.append(_label)
            img = self.process.width_pad_img(all_same_height_images[i],
                                             max_img_w)
            img = self.process.normalize_img(img)
            img = img.transpose([2, 0, 1])
            resize_images.append(jittor.array(img, dtype=jittor.float))
        resize_images = jittor.concat(resize_images, dim=1)
        return {'img': resize_images, 'label': labels}
Example #22
def draw_mask(img, ww):
    ch, h, w = img.shape
    fm_h, fm_w, h_, w_ = ww.shape
    center_fm_h = int(fm_h / 2)
    center_fm_w = int(fm_w / 2)
    # center_fm_h = 0
    # center_fm_w = 0
    h__ = int(h / fm_h)
    w__ = int(w / fm_w)
    mask = jt.zeros([ch, h, w])
    for i_ in range(h_):
        for j_ in range(w_):
            i = center_fm_h + i_ - int(h_ / 2)
            j = center_fm_w + j_ - int(w_ / 2)
            if (i < 0 or j < 0 or i >= fm_h or j >= fm_w):
                continue
            mask[:, i * h__:(i+1) * h__, j * w__:(j+1) * w__] = ww[center_fm_h, center_fm_w, i_, j_]
    mask_img = img / 2 + mask * 2
    # out = mask
    out = jt.concat([mask, mask_img], 2)
    return out
Example #23
def run_network(inputs,
                viewdirs,
                fn,
                embed_fn,
                embeddirs_fn,
                netchunk=1024 * 64):
    """Prepares inputs and applies network 'fn'.
    """
    inputs_flat = jt.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = jt.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = jt.concat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = jt.reshape(outputs_flat,
                         list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
Example #24
    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
        r"""Clips gradient norm of this optimizer.
        The norm is computed over all gradients together.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (int): 1-norm or 2-norm

        Example::

            a = jt.ones(2)
            opt = jt.optim.SGD([a], 0.1)

            loss = a*a
            opt.zero_grad()
            opt.backward(loss)

            print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
            opt.clip_grad_norm(0.01, 2)
            print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01
            
            opt.step()

        """
        if self.__zero_grad: return
        grads = []
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                grads.append(g.flatten())
        if len(grads) == 0: return
        total_norm = jt.norm(jt.concat(grads), norm_type)
        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                g *= clip_coef
Example #25
def style_mixing(generator, step, mean_style, n_source, n_target):
    source_code = jt.randn(n_source, 512)
    target_code = jt.randn(n_target, 512)

    shape = 4 * 2**step
    alpha = 1

    images = [jt.ones((1, 3, shape, shape)) * -1]

    source_image = generator(source_code,
                             step=step,
                             alpha=alpha,
                             mean_style=mean_style,
                             style_weight=0.7)
    target_image = generator(target_code,
                             step=step,
                             alpha=alpha,
                             mean_style=mean_style,
                             style_weight=0.7)

    images.append(source_image)

    for i in range(n_target):
        image = generator(
            [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
            step=step,
            alpha=alpha,
            mean_style=mean_style,
            style_weight=0.7,
            mixing_range=(0, 1),
        )
        images.append(target_image[i].unsqueeze(0))
        images.append(image)

    images = jt.concat(images, 0)

    return images
Example #26
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: ...
      verbose: bool. If True, print more debugging info.
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = jt.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1,1]

    z_vals = sample(N_rays, N_samples, lindisp, perturb, near, far)
    pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
        -1)  # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = integrator(
        raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
    #usefulRayIndex = jt.nonzero(acc_map > 0.1)
    if N_importance > 0:
        # importance sampling
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],
                               N_importance,
                               det=(perturb == 0.))
        z_samples = z_samples.detach()

        _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1)
        pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
            -1)  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, depth_map = integrator(
            raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0

    return ret
Example #27
def write_img(img, path):
    img = img.permute([1,2,0]) * 255
    img = jt.concat([img[:, :, 2:3], img[:, :, 1:2], img[:, :, 0:1]], 2)
    cv2.imwrite(path, img.data)
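
A minimal sketch (assuming jt is jittor and cv2 is imported): write_img expects a CHW image with values in [0, 1]; it scales to [0, 255] and reorders RGB to BGR before handing the array to OpenCV. The file name below is just a placeholder.

img = jt.random([3, 64, 64])   # toy image
write_img(img, 'noise.png')
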
Example #28
def render(H,
           W,
           focal,
           chunk=1024 * 32,
           rays=None,
           c2w=None,
           intrinsic=None,
           ndc=True,
           near=0.,
           far=1.,
           use_viewdirs=False,
           c2w_staticcam=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w, intrinsic)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            assert intrinsic is None
            rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / jt.norm(viewdirs, p=2, dim=-1, keepdim=True)
        viewdirs = jt.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = jt.reshape(rays_o, [-1, 3]).float()
    rays_d = jt.reshape(rays_d, [-1, 3]).float()

    near, far = near * jt.ones_like(rays_d[..., :1]), far * jt.ones_like(
        rays_d[..., :1])
    rays = jt.concat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = jt.concat([rays, viewdirs], -1)

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = jt.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
Example #29
def get_laplacian(edge_index,
                  edge_weight: Optional[Var] = None,
                  normalization: Optional[str] = None,
                  dtype: Optional[int] = None,
                  num_nodes: Optional[int] = None):
    r""" Computes the graph Laplacian of the graph given by :obj:`edge_index`
    and optional :obj:`edge_weight`.

    Args:
        edge_index (Var int32): The edge indices.
        edge_weight (Var, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`None`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
        dtype (Var.dtype, optional): The desired data type of returned Var
            in case :obj:`edge_weight=None`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
    """

    if normalization is not None:
        assert normalization in ['sym', 'rw']  # 'Invalid normalization'

    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

    if edge_weight is None:
        edge_weight = jt.ones((edge_index.size(1)), dtype=dtype)

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    row, col = edge_index[0], edge_index[1]
    shape = list(edge_weight.shape)
    shape[0] = num_nodes
    deg = jt.zeros(shape)
    deg = jt.scatter(deg, 0, row, src=edge_weight, reduce='add')
    if normalization is None:
        # L = D - A.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        edge_weight = jt.concat([-edge_weight, deg], dim=0)
    elif normalization == 'sym':
        # Compute A_norm = -D^{-1/2} A D^{-1/2}.
        deg_inv_sqrt = deg.pow(-0.5)
        # deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)

        for i in range(deg_inv_sqrt.shape[0]):
            if deg_inv_sqrt[i] == float('inf'):
                deg_inv_sqrt[i] = 0
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # L = I - A_norm.
        edge_index, tmp = add_self_loops(edge_index,
                                         -edge_weight,
                                         fill_value=1.,
                                         num_nodes=num_nodes)
        assert tmp is not None
        edge_weight = tmp
    else:
        # Compute A_norm = -D^{-1} A.
        deg_inv = 1.0 / deg
        deg_inv.masked_fill(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight

        # L = I - A_norm.
        edge_index, tmp = add_self_loops(edge_index,
                                         -edge_weight,
                                         fill_value=1.,
                                         num_nodes=num_nodes)
        assert tmp is not None
        edge_weight = tmp

    return edge_index, edge_weight
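
A minimal sketch (assuming jt is jittor and the helpers above, remove_self_loops / add_self_loops / maybe_num_nodes, are importable): symmetrically normalized Laplacian of a single undirected edge between two nodes.

edge_index = jt.array([[0, 1], [1, 0]])
edge_weight = jt.ones((2,))
edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                        normalization='sym', num_nodes=2)
# off-diagonal entries are -1 / sqrt(d_i * d_j) = -1; diagonal entries are 1
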
Example #30
def read_planetoid_data(folder, prefix):
    names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
    items = [read_file(folder, prefix, name) for name in names]
    x, tx, allx, y, ty, ally, graph, test_index = items
    train_index = jt.arange(y.size(0), dtype=Var.int32)
    val_index = jt.arange(y.size(0), y.size(0) + 500, dtype=Var.int32)
    sorted_test_index = test_index.argsort()[1]

    if prefix.lower() == 'citeseer':
        # There are some isolated nodes in the Citeseer graph, resulting in
        # non-consecutive test indices. We need to identify them and add them
        # as zero vectors to `tx` and `ty`.
        len_test_indices = (test_index.max() - test_index.min()).item() + 1

        tx_ext = jt.zeros((len_test_indices, tx.size(1)))
        tx_ext[sorted_test_index - test_index.min(), :] = tx
        ty_ext = jt.zeros((len_test_indices, ty.size(1)))
        ty_ext[sorted_test_index - test_index.min(), :] = ty

        tx, ty = tx_ext, ty_ext

    if prefix.lower() == 'nell.0.001':
        tx_ext = jt.zeros((len(graph) - allx.size(0), x.size(1)))
        tx_ext[sorted_test_index - allx.size(0)] = tx

        ty_ext = jt.zeros((len(graph) - ally.size(0), y.size(1)))
        ty_ext[sorted_test_index - ally.size(0)] = ty

        tx, ty = tx_ext, ty_ext

        x = jt.concat([allx, tx], dim=0)
        x[test_index] = x[sorted_test_index]

        # Creating feature vectors for relations.
        row, col, value = SparseTensor.from_dense(x).coo()
        rows, cols, values = [row], [col], [value]

        mask1 = index_to_mask(test_index, size=len(graph))
        mask2 = index_to_mask(jt.arange(allx.size(0), len(graph)),
                              size=len(graph))
        mask = jt.logical_or(jt.logical_not(mask1), jt.logical_not(mask2))
        isolated_index = mask.nonzero(as_tuple=False).view(-1)[allx.size(0):]

        rows += [isolated_index]
        cols += [jt.arange(isolated_index.size(0)) + x.size(1)]
        values += [jt.ones((isolated_index.size(0)))]

        x = SparseTensor(row=jt.concat(rows),
                         col=jt.concat(cols),
                         value=jt.concat(values))
    else:
        x = jt.concat([allx, tx], dim=0)
        x[test_index] = x[sorted_test_index]
    y = jt.concat([ally, ty], dim=0).argmax(dim=1)[0]
    y[test_index] = y[sorted_test_index]

    train_mask = index_to_mask(train_index, size=y.size(0))
    val_mask = index_to_mask(val_index, size=y.size(0))
    test_mask = index_to_mask(test_index, size=y.size(0))

    edge_index = edge_index_from_dict(graph, num_nodes=y.size(0))

    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask
    return data