    def findLR(self, model, optimizer, writer,
               start_lr=1e-7, end_lr=10, num_iters=50):
        model.train()

        losses = []
        lrs = np.logspace(np.log10(start_lr), np.log10(end_lr), num_iters)
        data_iter = iter(self.data_loaders[0])

        for lr in lrs:
            # Update LR
            for group in optimizer.param_groups:
                group['lr'] = lr

            # Fetch the next batch, restarting the loader once it is exhausted
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loaders[0])
                batch = next(data_iter)
            input_images, depthGT, maskGT = utils.unpack_batch_fixed(batch, self.cfg.device)
            # ------ define ground truth------
            XGT, YGT = torch.meshgrid([torch.arange(self.cfg.outH), # [H,W]
                                       torch.arange(self.cfg.outW)]) # [H,W]
            XGT, YGT = XGT.float(), YGT.float()
            XYGT = torch.cat([
                XGT.repeat([self.cfg.outViewN, 1, 1]), 
                YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0) #[2V,H,W]
            XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device) #[1,2V,H,W]

            with torch.set_grad_enabled(True):
                optimizer.zero_grad()

                XYZ, maskLogit = model(input_images)
                XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
                depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN * 3, :,  :]
                mask = (maskLogit > 0).bool()
                # ------ Compute loss ------
                loss_XYZ = self.l1(XY, XYGT)
                loss_XYZ += self.l1(depth.masked_select(mask),
                                    depthGT.masked_select(mask))
                loss_mask = self.sigmoid_bce(maskLogit, maskGT)
                loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

                # Update weights
                loss.backward()
                # True Weight decay
                if self.cfg.trueWD is not None:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            # Decoupled weight decay: w <- w * (1 - lr * wd)
                            param.data.add_(param.data,
                                            alpha=-self.cfg.trueWD * group['lr'])
                optimizer.step()

            losses.append(loss.item())

        fig, ax = plt.subplots()
        ax.plot(lrs, losses)
        ax.set_xlabel('learning rate')
        ax.set_ylabel('loss')
        ax.set_xscale('log')
        writer.add_figure('findLR', fig)
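
The learning-rate range test above is self-contained enough to extract; a minimal standalone sketch, assuming a generic model, loss_fn and loader (hypothetical names, not part of the trainer above):

import numpy as np
import torch

def find_lr(model, loss_fn, optimizer, loader,
            start_lr=1e-7, end_lr=10, num_iters=50):
    """Sweep the LR log-uniformly and record the loss after each step."""
    model.train()
    lrs = np.logspace(np.log10(start_lr), np.log10(end_lr), num_iters)
    losses, data_iter = [], iter(loader)
    for lr in lrs:
        for group in optimizer.param_groups:
            group['lr'] = lr
        try:
            inputs, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            inputs, targets = next(data_iter)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return lrs, losses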

    def _val_on_epoch(self, model):
        model.eval()

        data_loader = self.data_loaders[1]
        running_loss_XYZ = 0.0
        running_loss_mask = 0.0
        running_loss = 0.0

        for batch in data_loader:
            input_images, depthGT, maskGT = utils.unpack_batch_fixed(batch, self.cfg.device)
            # ------ define ground truth------
            XGT, YGT = torch.meshgrid([
                torch.arange(self.cfg.outH), # [H,W]
                torch.arange(self.cfg.outW)]) # [H,W]
            XGT, YGT = XGT.float(), YGT.float()
            XYGT = torch.cat([
                XGT.repeat([self.cfg.outViewN, 1, 1]), 
                YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0) #[2V,H,W]
            XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device) # [1,2V,H,W] 

            with torch.set_grad_enabled(False):
                XYZ, maskLogit = model(input_images)
                XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
                depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN*3,:,:]
                mask = (maskLogit > 0).bool()
                # ------ Compute loss ------
                loss_XYZ = self.l1(XY, XYGT)
                loss_XYZ += self.l1(depth.masked_select(mask),
                                    depthGT.masked_select(mask))
                loss_mask = self.sigmoid_bce(maskLogit, maskGT)
                loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

            running_loss_XYZ += loss_XYZ.item() * input_images.size(0)
            running_loss_mask += loss_mask.item() * input_images.size(0)
            running_loss += loss.item() * input_images.size(0)

        epoch_loss_XYZ = running_loss_XYZ / len(data_loader.dataset)
        epoch_loss_mask = running_loss_mask / len(data_loader.dataset)
        epoch_loss = running_loss / len(data_loader.dataset)

        print(f"\tVal loss: {epoch_loss}")
        return {"epoch_loss_XYZ": epoch_loss_XYZ,
                "epoch_loss_mask": epoch_loss_mask,
                "epoch_loss": epoch_loss, }

    def grid_anchors(self, grid_sizes):
        anchors = []
        for size, stride, base_anchors in zip(
            grid_sizes, self.strides, self.cell_anchors
        ):
            grid_height, grid_width = size
            device = base_anchors.device
            shifts_x = torch.arange(
                0, grid_width * stride, step=stride, dtype=torch.float32, device=device
            )
            shifts_y = torch.arange(
                0, grid_height * stride, step=stride, dtype=torch.float32, device=device
            )
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors
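
The anchor generation above relies on a (K, 1, 4) + (1, A, 4) broadcast; a toy check of just that step (toy sizes, not tied to the class above):

import torch

# one 2x2 feature map, stride 4, a single 8x8 base anchor centered at the origin
base_anchors = torch.tensor([[-4., -4., 4., 4.]])                        # (A, 4), A = 1
shifts_x = torch.arange(0, 2 * 4, step=4, dtype=torch.float32)
shifts_y = torch.arange(0, 2 * 4, step=4, dtype=torch.float32)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shifts = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1),
                      shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)  # (K, 4), K = 4
anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
print(anchors.shape)  # torch.Size([4, 4]): one anchor per grid cell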
Example n. 4

def get_grid_coords_3d(self, z, y, x, coord_dim=-1):
    z, y, x = torch.meshgrid(z, y, x)
    coords = torch.stack([x, y, z], dim=coord_dim)
    return coords
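
A quick shape check of the pattern used above, with the method body inlined on 1-D coordinate tensors:

import torch

z, y, x = torch.arange(2.), torch.arange(3.), torch.arange(4.)
zz, yy, xx = torch.meshgrid(z, y, x)
coords = torch.stack([xx, yy, zz], dim=-1)
print(coords.shape)  # torch.Size([2, 3, 4, 3]): an (x, y, z) triple per grid cell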
Example n. 5

def _make_grid(nx=20, ny=20):
    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
    return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
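
For reference, what _make_grid produces; note the stack order puts x before y:

grid = _make_grid(nx=3, ny=2)
print(grid.shape)        # torch.Size([1, 1, 2, 3, 2])
print(grid[0, 0, 1, 2])  # tensor([2., 1.]) -> (x=2, y=1) for row 1, column 2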
Example n. 6

    def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
                 centerness, wh_weight, hm_weight):
        """

        Args:
            pred_hm: tensor, (batch, 80, h, w).
            pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            pred_centerness: tensor or None, (batch, 1, h, w).
            heatmap: tensor, (batch, 80, h, w).
            box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
            centerness: tensor or None, (batch, 1, h, w).
            wh_weight: tensor or None, (batch, 80, h, w).

        Returns:

        """
        if every_n_local_step(100):
            pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                          min=1e-4,
                                          max=1 - 1e-4)
            gt_hm_summary = heatmap.clone()
            if self.fovea_hm:
                if not self.only_merge:
                    pred_ctn_summary = torch.clamp(
                        torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                    add_feature_summary(
                        'centernet/centerness',
                        pred_ctn_summary.detach().cpu().numpy(),
                        type='f')
                    add_feature_summary(
                        'centernet/merge',
                        (pred_ctn_summary *
                         pred_hm_summary).detach().cpu().numpy(),
                        type='max')

                add_feature_summary('centernet/gt_centerness',
                                    centerness.detach().cpu().numpy(),
                                    type='f')
                add_feature_summary('centernet/gt_merge',
                                    (centerness *
                                     gt_hm_summary).detach().cpu().numpy(),
                                    type='max')

            add_feature_summary('centernet/heatmap',
                                pred_hm_summary.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                gt_hm_summary.detach().cpu().numpy())

        H, W = pred_hm.shape[2:]
        if not self.fovea_hm:
            pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
            hm_weight = None if self.ct_version else hm_weight
            hm_loss = ct_focal_loss(pred_hm, heatmap,
                                    hm_weight=hm_weight) * self.hm_weight
            centerness_loss = hm_loss.new_tensor([0.])
            merge_loss = hm_loss.new_tensor([0.])
        else:
            care_mask = (heatmap >= 0).float()
            avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
            if not self.only_merge:
                hm_loss = py_sigmoid_focal_loss(
                    pred_hm, heatmap, care_mask,
                    reduction='sum') / avg_factor * self.hm_weight

                pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                              min=1e-4,
                                              max=1 - 1e-4)
                centerness_loss = ct_focal_loss(
                    pred_centerness, centerness, gamma=2.) * self.ct_weight

                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                                min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight
            else:
                hm_loss = pred_hm.new_tensor([0.])
                centerness_loss = pred_hm.new_tensor([0.])
                merge_loss = ct_focal_loss(
                    torch.clamp(torch.sigmoid(pred_hm), min=1e-4,
                                max=1 - 1e-4),
                    heatmap * centerness,
                    weight=(heatmap >= 0).float()) * self.merge_weight

        if not self.wh_agnostic:
            pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
            box_target = box_target.view(
                box_target.size(0) * pred_hm.size(1), 4, H, W)
        mask = wh_weight.view(-1, H, W)
        avg_factor = mask.sum() + 1e-4

        if self.base_loc is None:
            base_step = self.down_ratio
            shifts_x = torch.arange(0, (W - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shifts_y = torch.arange(0, (H - 1) * base_step + 1,
                                    base_step,
                                    dtype=torch.float32,
                                    device=heatmap.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

        # (batch, h, w, 4)
        pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                                self.base_loc + pred_wh[:, [2, 3]]),
                               dim=1).permute(0, 2, 3, 1)
        # (batch, h, w, 4)
        boxes = box_target.permute(0, 2, 3, 1)
        wh_loss = giou_loss(pred_boxes, boxes, mask,
                            avg_factor=avg_factor) * self.giou_weight

        return hm_loss, wh_loss, centerness_loss, merge_loss
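
The base_loc grid above turns per-pixel (l, t, r, b) offsets into absolute boxes; the same step in isolation (toy sizes, hypothetical values):

import torch

H, W, down_ratio = 2, 3, 4
shifts_x = torch.arange(0, (W - 1) * down_ratio + 1, down_ratio, dtype=torch.float32)
shifts_y = torch.arange(0, (H - 1) * down_ratio + 1, down_ratio, dtype=torch.float32)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
base_loc = torch.stack((shift_x, shift_y), dim=0)        # (2, H, W) pixel centers
pred_wh = torch.ones(1, 4, H, W)                         # distances to the 4 box sides
pred_boxes = torch.cat((base_loc - pred_wh[:, [0, 1]],   # x1, y1
                        base_loc + pred_wh[:, [2, 3]]),  # x2, y2
                       dim=1).permute(0, 2, 3, 1)        # (1, H, W, 4)
print(pred_boxes[0, 0, 1])  # tensor([ 3., -1.,  5.,  1.]) around center (4, 0)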
Example n. 7

def test_meshgrid(self):
    x = torch.ones(3, requires_grad=True)
    y = torch.zeros(4, requires_grad=True)
    z = torch.ones(5, requires_grad=True)
    self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))
Example n. 8

def warp_image_torch(img, H, out_shape=None):
    """Apply an homography to a torch Tensor

    Arguments:
        img: Tensor of shape (B,C,H,W) or (C,H,W)
        H: Tensor of shape (B,3,3) or (3,3), the homography
        out_shape: Tuple, the wanted shape of the out image
    Returns:
        A Tensor of shape (B) x (out_shape) or (B) x (img.shape), the warped image 
    Raises:
        ValueError: If img and H batch sizes are different
    """
    if out_shape is None:
        out_shape = img.shape[-2:]
    if len(img.shape) < 4:
        img = img[None]
    if len(H.shape) < 3:
        H = H[None]
    if img.shape[0] != H.shape[0]:
        raise ValueError(
            "batch size of images ({}) does not match the batch size of homographies ({})"
            .format(img.shape[0], H.shape[0]))
    batchsize = img.shape[0]
    # create grid for interpolation (in frame coordinates)

    y, x = torch.meshgrid([
        torch.linspace(-0.5, 0.5, steps=out_shape[-2]),
        torch.linspace(-0.5, 0.5, steps=out_shape[-1]),
    ])
    x = x.to(img.device)
    y = y.to(img.device)
    x, y = x.flatten(), y.flatten()

    # append ones for homogeneous coordinates
    xy = torch.stack([x, y, torch.ones_like(x)])
    xy = xy.repeat([batchsize, 1, 1])  # shape: (B, 3, N)
    # warp points to model coordinates
    xy_warped = torch.matmul(H, xy)  # H.bmm(xy)
    xy_warped, z_warped = xy_warped.split(2, dim=1)

    # we multiply by 2, since our homographies map to
    # coordinates in the range [-0.5, 0.5] (the ones in our GT datasets)
    xy_warped = 2.0 * xy_warped / (z_warped + 1e-8)
    x_warped, y_warped = torch.unbind(xy_warped, dim=1)
    # build grid
    grid = torch.stack(
        [
            x_warped.view(batchsize, *out_shape[-2:]),
            y_warped.view(batchsize, *out_shape[-2:]),
        ],
        dim=-1,
    )

    # sample warped image
    warped_img = torch.nn.functional.grid_sample(img,
                                                 grid,
                                                 mode="bilinear",
                                                 padding_mode="zeros")

    if torch.isnan(warped_img).any():
        print("nan value in warped image! set to zeros")
        warped_img[torch.isnan(warped_img)] = 0

    return warped_img
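
Assuming the function above (with torch imported), the identity homography maps the [-0.5, 0.5] grid onto grid_sample's [-1, 1] range, so it should reproduce the input up to bilinear resampling:

img = torch.rand(2, 3, 8, 8)
H_id = torch.eye(3).expand(2, 3, 3)
out = warp_image_torch(img, H_id)
print(out.shape)  # torch.Size([2, 3, 8, 8])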
Example n. 9

def render(seed=None,
           xlim=[-1.0, 1.0],
           ylim=None,
           xres=1024,
           yres=None,
           units=16,
           depth=8,
           hidden_std=1.0,
           output_std=1.0,
           channels=3,
           radius=True,
           bias=True,
           z=None,
           device='cpu'):
    if device not in get_devices():
        raise RuntimeError('Device {} not in available devices: {}'.format(
            device, ', '.join(get_devices())))

    cpu_rng_state = torch.get_rng_state()
    cuda_rng_states = []
    if torch.cuda.is_available():
        cuda_rng_states = [torch.cuda.get_rng_state(idx) for idx in range(torch.cuda.device_count())]

    if seed is None:
        seed = random.Random().randint(0, 2 ** 32 - 1)

    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

    if ylim is None:
        ylim = copy.copy(xlim)

    if yres is None:
        yxscale = float(ylim[1] - ylim[0]) / (xlim[1] - xlim[0])
        yres = int(yxscale * xres)

    x = torch.linspace(xlim[0], xlim[1], xres, device=device)
    y = torch.linspace(ylim[0], ylim[1], yres, device=device)
    meshgrid_kwargs = {}
    if inspect.signature(torch.meshgrid).parameters.get('indexing'):
        meshgrid_kwargs['indexing'] = 'ij'
    grid = torch.meshgrid((y, x), **meshgrid_kwargs)

    inputs = torch.cat((grid[0].flatten().unsqueeze(1), grid[1].flatten().unsqueeze(1)), -1)

    if radius:
        inputs = torch.cat((inputs, torch.norm(inputs, 2, 1).unsqueeze(1)), -1)

    if z is not None:
        zrep = torch.tensor(z, dtype=inputs.dtype, device=device).repeat((inputs.shape[0], 1))
        inputs = torch.cat((inputs, zrep), -1)

    n_hidden_units = [units] * depth

    activations = inputs
    for units in n_hidden_units:
        if bias:
            bias_array = torch.ones((activations.shape[0], 1), device=device)
            activations = torch.cat((bias_array, activations), -1)
        hidden_layer_weights = torch.randn((activations.shape[1], units), device=device) * hidden_std
        activations = torch.tanh(torch.mm(activations, hidden_layer_weights))

    if bias:
        bias_array = torch.ones((activations.shape[0], 1), device=device)
        activations = torch.cat((bias_array, activations), -1)
    output_layer_weights = torch.randn((activations.shape[1], channels), device=device) * output_std
    output = torch.sigmoid(torch.mm(activations, output_layer_weights))
    output = output.reshape((yres, xres, channels))

    torch.set_rng_state(cpu_rng_state)
    for idx, cuda_rng_state in enumerate(cuda_rng_states):
        torch.cuda.set_rng_state(cuda_rng_state, idx)

    return (output.cpu() * 255).round().type(torch.uint8).numpy()
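
A small usage sketch, assuming the module's helpers (get_devices etc.) are importable alongside render:

img = render(seed=42, xres=256, yres=256, channels=3, device='cpu')
print(img.shape, img.dtype)  # (256, 256, 3) uint8
# e.g. save with Pillow:
# from PIL import Image; Image.fromarray(img).save('cppn.png')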
Example n. 10

def create_plots(robot, obstacles, dist_est, checker):
    from matplotlib.cm import get_cmap
    cmaps = [get_cmap('Reds'), get_cmap('Blues')]

    if robot.dof > 2:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111)  #, projection='3d'
    elif robot.dof == 2:
        # Show C-space at the same time
        num_class = getattr(checker, 'num_class', 1)
        fig = plt.figure(figsize=(3 * (num_class), 3 * (num_class + 1)))
        plt.rcParams.update({
            "text.usetex": True,
            "font.family": "sans-serif",
            "font.sans-serif": ["Helvetica"]
        })
        gs = fig.add_gridspec(num_class + 1, num_class)
        ax = fig.add_subplot(
            gs[:-1, :]
        )  #sum([list(range(r*(num_class+1)+1, (r+1)*(num_class+1))) for r in range(num_class)], [])) #, projection='3d'
        cfg_path_plots = []

        size = [400, 400]
        yy, xx = torch.meshgrid(torch.linspace(-np.pi, np.pi, size[0]),
                                torch.linspace(-np.pi, np.pi, size[1]))
        grid_points = torch.stack([xx, yy], dim=2).reshape((-1, 2))
        grid_points = grid_points.double(
        ) if checker.support_points.dtype == torch.float64 else grid_points
        score_spline = dist_est(grid_points).reshape(size + [num_class])
        c_axes = []
        with sns.axes_style('ticks'):
            for cat in range(num_class):
                c_ax = fig.add_subplot(gs[-1, cat])

                # score_DiffCo = checker.score(grid_points).reshape(size)
                # score = (torch.sign(score_DiffCo)+1)/2*(score_spline-score_spline.min()) + (-torch.sign(score_DiffCo)+1)/2*(score_spline-score_spline.max())
                score = score_spline[:, :, cat]
                color_mesh = c_ax.pcolormesh(xx,
                                             yy,
                                             score,
                                             cmap=cmaps[cat],
                                             vmin=-torch.abs(score).max(),
                                             vmax=torch.abs(score).max())
                c_support_points = checker.support_points[
                    checker.gains[:, cat] != 0]
                c_ax.scatter(c_support_points[:, 0],
                             c_support_points[:, 1],
                             marker='.',
                             c='black',
                             s=1.5)
                c_ax.contour(
                    xx,
                    yy,
                    score,
                    levels=[0],
                    linewidths=1,
                    alpha=0.4,
                )  #-1.5, -0.75, 0, 0.3
                # fig.colorbar(color_mesh, ax=c_ax)
                # sparse_score = score[5:-5:10, 5:-5:10]
                # score_grad_x = -ndimage.sobel(sparse_score.numpy(), axis=1)
                # score_grad_y = -ndimage.sobel(sparse_score.numpy(), axis=0)
                # score_grad = np.stack([score_grad_x, score_grad_y], axis=2)
                # score_grad /= np.linalg.norm(score_grad, axis=2, keepdims=True)
                # score_grad_x, score_grad_y = score_grad[:, :, 0], score_grad[:, :, 1]
                # c_ax.quiver(xx[5:-5:10, 5:-5:10], yy[5:-5:10, 5:-5:10], score_grad_x, score_grad_y, color='red', width=2e-3, headwidth=2, headlength=5)
                # cfg_point = Circle(collision_cfgs[0], radius=0.05, facecolor='orange', edgecolor='black', path_effects=[path_effects.withSimplePatchShadow()])
                # c_ax.add_patch(cfg_point)
                for _ in range(4):
                    cfg_path, = c_ax.plot([], [],
                                          '-o',
                                          c='olivedrab',
                                          markersize=3)
                    cfg_path_plots.append(cfg_path)

                c_ax.set_aspect('equal', adjustable='box')
                # c_ax.axis('equal')
                c_ax.set_xlim(-np.pi, np.pi)
                c_ax.set_ylim(-np.pi, np.pi)
                c_ax.set_xticks([-np.pi, 0, np.pi])
                c_ax.set_xticklabels([r'$-\pi$', '$0$', r'$\pi$'])
                c_ax.set_yticks([-np.pi, 0, np.pi])
                c_ax.set_yticklabels([r'$-\pi$', '$0$', r'$\pi$'])
                # c_ax.tick_params(direction='in', reset=True)
                # c_ax.tick_params(which='both', direction='out', length=6, width=2, colors='r',
                #    grid_color='r', grid_alpha=0.5)
            # c_ax.set_ticks('')

    # Plot obstacles
    # ax.axis('tight')
    ax.set_xlim(-8, 8)
    ax.set_ylim(-8, 8)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xticks([-4, 0, 4])
    ax.set_yticks([-4, 0, 4])
    for obs in obstacles:
        cat = obs[3] if len(obs) >= 4 else 1
        if obs[0] == 'circle':
            ax.add_patch(
                Circle(obs[1],
                       obs[2],
                       path_effects=[path_effects.withSimplePatchShadow()],
                       color=cmaps[cat](0.5)))
        elif obs[0] == 'rect':
            ax.add_patch(
                Rectangle((obs[1][0] - float(obs[2][0]) / 2,
                           obs[1][1] - float(obs[2][1]) / 2),
                          obs[2][0],
                          obs[2][1],
                          path_effects=[path_effects.withSimplePatchShadow()],
                          color=cmaps[cat](0.5)))
            # print((obs[1][0]-obs[2][0]/2, obs[1][1]-obs[2][1]/2))

    # Placeholder of the robot plot
    trans = ax.transData.transform
    lw = ((trans((1, robot.link_width)) - trans(
        (0, 0))) * 72 / ax.figure.dpi)[1]
    link_plot, = ax.plot(
        [], [],
        color='silver',
        alpha=0.1,
        lw=lw,
        solid_capstyle='round',
        path_effects=[path_effects.SimpleLineShadow(),
                      path_effects.Normal()])
    joint_plot, = ax.plot([], [], 'o', color='tab:red', markersize=lw)
    eff_plot, = ax.plot([], [], 'o', color='black', markersize=lw)

    if robot.dof > 2:
        return fig, ax, link_plot, joint_plot, eff_plot
    elif robot.dof == 2:
        return fig, ax, link_plot, joint_plot, eff_plot, cfg_path_plots
Example n. 11

    def fcos_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        controllers_pred,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        matched_idxes,
        im_idxes,
        locations,
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = (sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=focal_loss_alpha,
            gamma=focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg)

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        controllers_pred = controllers_pred[pos_inds]
        matched_idxes = matched_idxes[pos_inds]
        im_idxes = im_idxes[pos_inds]
        locations = locations[pos_inds]

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = iou_loss(reg_pred, reg_targets,
                            ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        # for CondInst
        batch_ins = pos_inds.shape[0]
        N, C, h, w = self.masks.shape
        center_x = torch.clamp(locations[:, 0], min=0, max=w - 1).long()
        center_y = torch.clamp(locations[:, 1], min=0, max=h - 1).long()
        x_range = torch.linspace(-1, 1, w, device=self.masks.device)
        y_range = torch.linspace(-1, 1, h, device=self.masks.device)
        y, x = torch.meshgrid(y_range, x_range)
        x = x.unsqueeze(0).unsqueeze(0)
        y = y.unsqueeze(0).unsqueeze(0)
        grid = torch.cat([x, y], 1)
        offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        offset_xy = torch.cat([offset_x, offset_y], 1)
        coords_feat = grid - offset_xy
        masks_feat = self.masks
        r_h = int(h * self.strides[0])
        r_w = int(w * self.strides[0])
        targets_masks = [
            target_im.gt_masks.tensor for target_im in self.gt_instances
        ]
        masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)
        mask_loss = masks_feat[0].new_tensor(0.0)
        batch_ins = im_idxes.shape[0]
        # for each image
        for i in range(N):
            inds = (im_idxes == i).nonzero().flatten()
            ins_num = inds.shape[0]
            if ins_num > 0:
                controllers = controllers_pred[inds]
                coord_feat = coords_feat[inds]
                mask_feat = masks_feat[None, i]
                mask_feat = torch.cat([mask_feat] * ins_num, dim=0)
                comb_feat = torch.cat((mask_feat, coord_feat),
                                      dim=1).view(1, -1, h, w)
                weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(
                    controllers, [80, 8, 64, 8, 8, 1], dim=1)
                bias1, bias2, bias3 = bias1.flatten(), bias2.flatten(
                ), bias3.flatten()
                weight1 = weight1.reshape(-1, 8, 10).reshape(
                    -1, 10).unsqueeze(-1).unsqueeze(-1)
                weight2 = weight2.reshape(-1, 8, 8).reshape(
                    -1, 8).unsqueeze(-1).unsqueeze(-1)
                weight3 = weight3.unsqueeze(-1).unsqueeze(-1)
                conv1 = F.conv2d(comb_feat, weight1, bias1,
                                 groups=ins_num).relu()
                conv2 = F.conv2d(conv1, weight2, bias2, groups=ins_num).relu()
                masks_per_image = F.conv2d(conv2,
                                           weight3,
                                           bias3,
                                           groups=ins_num)
                masks_per_image = aligned_bilinear(
                    masks_per_image, self.strides[0])[0].sigmoid()
                for j in range(ins_num):
                    ind = inds[j]
                    mask_gt = masks_t[i][matched_idxes[ind]].float()
                    mask_pred = masks_per_image[j]
                    mask_loss += self.dice_loss(mask_pred, mask_gt)

        if batch_ins > 0:
            mask_loss = mask_loss / batch_ins

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            "loss_mask": mask_loss,
        }
        return losses, {}
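
The CondInst-style relative coordinates built above (a fixed [-1, 1] grid minus each instance's center) are easier to see in isolation; a toy version with one instance:

import torch

h, w = 3, 5
x_range = torch.linspace(-1, 1, w)
y_range = torch.linspace(-1, 1, h)
y, x = torch.meshgrid(y_range, x_range)
grid = torch.cat([x.unsqueeze(0).unsqueeze(0),
                  y.unsqueeze(0).unsqueeze(0)], 1)          # (1, 2, h, w)
center_x, center_y = torch.tensor([2]), torch.tensor([1])   # instance at pixel (2, 1)
offset_x = x_range[center_x].view(-1, 1, 1, 1)
offset_y = y_range[center_y].view(-1, 1, 1, 1)
coords = grid - torch.cat([offset_x, offset_y], 1)          # zero at the instance center
print(coords[0, :, 1, 2])  # tensor([0., 0.])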
Example n. 12

    def combine(self, local_parameters, param_idx, user_idx):
        count = OrderedDict()
        if cfg['model_name'] == 'conv':
            output_weight_name = [
                k for k in self.global_parameters.keys() if 'weight' in k
            ][-1]
            output_bias_name = [
                k for k in self.global_parameters.keys() if 'bias' in k
            ][-1]
            for k, v in self.global_parameters.items():
                parameter_type = k.split('.')[-1]
                count[k] = v.new_zeros(v.size(), dtype=torch.float32)
                tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
                for m in range(len(local_parameters)):
                    if 'weight' in parameter_type or 'bias' in parameter_type:
                        if parameter_type == 'weight':
                            if v.dim() > 1:
                                if k == output_weight_name:
                                    label_split = self.label_split[user_idx[m]]
                                    param_idx[m][k] = list(param_idx[m][k])
                                    param_idx[m][k][0] = param_idx[m][k][0][
                                        label_split]
                                    tmp_v[torch.meshgrid(
                                        param_idx[m][k]
                                    )] += local_parameters[m][k][label_split]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                                else:
                                    tmp_v[torch.meshgrid(
                                        param_idx[m]
                                        [k])] += local_parameters[m][k]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                        else:
                            if k == output_bias_name:
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = param_idx[m][k][label_split]
                                tmp_v[param_idx[m][k]] += local_parameters[m][
                                    k][label_split]
                                count[k][param_idx[m][k]] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                    else:
                        tmp_v += local_parameters[m][k]
                        count[k] += 1
                tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                    count[k][count[k] > 0])
                v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
        elif 'resnet' in cfg['model_name']:
            for k, v in self.global_parameters.items():
                parameter_type = k.split('.')[-1]
                count[k] = v.new_zeros(v.size(), dtype=torch.float32)
                tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
                for m in range(len(local_parameters)):
                    if 'weight' in parameter_type or 'bias' in parameter_type:
                        if parameter_type == 'weight':
                            if v.dim() > 1:
                                if 'linear' in k:
                                    label_split = self.label_split[user_idx[m]]
                                    param_idx[m][k] = list(param_idx[m][k])
                                    param_idx[m][k][0] = param_idx[m][k][0][
                                        label_split]
                                    tmp_v[torch.meshgrid(
                                        param_idx[m][k]
                                    )] += local_parameters[m][k][label_split]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                                else:
                                    tmp_v[torch.meshgrid(
                                        param_idx[m]
                                        [k])] += local_parameters[m][k]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                        else:
                            if 'linear' in k:
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = param_idx[m][k][label_split]
                                tmp_v[param_idx[m][k]] += local_parameters[m][
                                    k][label_split]
                                count[k][param_idx[m][k]] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                    else:
                        tmp_v += local_parameters[m][k]
                        count[k] += 1
                tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                    count[k][count[k] > 0])
                v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
        elif cfg['model_name'] == 'transformer':
            for k, v in self.global_parameters.items():
                parameter_type = k.split('.')[-1]
                count[k] = v.new_zeros(v.size(), dtype=torch.float32)
                tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
                for m in range(len(local_parameters)):
                    if 'weight' in parameter_type or 'bias' in parameter_type:
                        if 'weight' in parameter_type:
                            if v.dim() > 1:
                                if k.split('.')[-2] == 'embedding':
                                    label_split = self.label_split[user_idx[m]]
                                    param_idx[m][k] = list(param_idx[m][k])
                                    param_idx[m][k][0] = param_idx[m][k][0][
                                        label_split]
                                    tmp_v[torch.meshgrid(
                                        param_idx[m][k]
                                    )] += local_parameters[m][k][label_split]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                                elif 'decoder' in k and 'linear2' in k:
                                    label_split = self.label_split[user_idx[m]]
                                    param_idx[m][k] = list(param_idx[m][k])
                                    param_idx[m][k][0] = param_idx[m][k][0][
                                        label_split]
                                    tmp_v[torch.meshgrid(
                                        param_idx[m][k]
                                    )] += local_parameters[m][k][label_split]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                                else:
                                    tmp_v[torch.meshgrid(
                                        param_idx[m]
                                        [k])] += local_parameters[m][k]
                                    count[k][torch.meshgrid(
                                        param_idx[m][k])] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                        else:
                            if 'decoder' in k and 'linear2' in k:
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = param_idx[m][k][label_split]
                                tmp_v[param_idx[m][k]] += local_parameters[m][
                                    k][label_split]
                                count[k][param_idx[m][k]] += 1
                            else:
                                tmp_v[param_idx[m]
                                      [k]] += local_parameters[m][k]
                                count[k][param_idx[m][k]] += 1
                    else:
                        tmp_v += local_parameters[m][k]
                        count[k] += 1
                tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                    count[k][count[k] > 0])
                v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
        else:
            raise ValueError('Not valid model name')

        return
Example n. 13

from typing import Tuple

def meshgrid(a: torch.Tensor,
             b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    x = torch.meshgrid(a, b, indexing="ij")
    return x[0], x[1]
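
The indexing argument is the main gotcha with torch.meshgrid; the two conventions are transposes of each other:

import torch

a, b = torch.arange(2), torch.arange(3)
gi, gj = torch.meshgrid(a, b, indexing="ij")  # shapes (2, 3): matrix convention
gx, gy = torch.meshgrid(a, b, indexing="xy")  # shapes (3, 2): Cartesian convention
print(gi.shape, gx.shape)     # torch.Size([2, 3]) torch.Size([3, 2])
print(torch.equal(gi, gx.T))  # True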
Example n. 14

        if (epoch == 0) and testing:
            print('TARGETS: ', targets)
            model.train(False)
            encodings = model.test_encode(batch)
            print('ENCODINGS: ', encodings)


# =============================================================================
# Setup the auxiliary tensors for the training run
# =============================================================================

batch_size = 64

l = torch.arange(0 - 63 / 2., (63 / 2.) + 1)
yyy, xxx, zzz = torch.meshgrid(l, l, l)

xxx, yyy, zzz = xxx.repeat(batch_size, 1, 1,
                           1), yyy.repeat(batch_size, 1, 1,
                                          1), zzz.repeat(batch_size, 1, 1, 1)
xxx = xxx.to(device).to(torch.float)
yyy = yyy.to(device).to(torch.float)
zzz = zzz.to(device).to(torch.float)

mask = torch.zeros((xxx.shape)).to(device).to(torch.float)
thresh = torch.tensor([3]).to(device).to(torch.float)
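
The snippet does not show how mask and thresh are used; one plausible use of these grids, purely as a hedged sketch, is a spherical mask around the volume center:

# assumption: mask selects voxels within radius `thresh` of the center
r = torch.sqrt(xxx ** 2 + yyy ** 2 + zzz ** 2)
mask = (r < thresh).to(torch.float)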

# =============================================================================
# Instantiate the model
# =============================================================================
Example n. 15

def proj_cost(settings, ref_feature, src_feature, level, ref_in, src_in,
              ref_ex, src_ex, depth_hypos):
    ## Calculate the cost volume for refined depth hypothesis selection

    batch, channels = ref_feature.shape[0], ref_feature.shape[1]
    num_depth = depth_hypos.shape[1]
    height, width = ref_feature.shape[2], ref_feature.shape[3]
    nSrc = len(src_feature)

    volume_sum = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
    volume_sq_sum = volume_sum.pow(2)  # out-of-place: volume_sum is reused below

    for src in range(settings.nsrc):

        with torch.no_grad():
            src_proj = torch.matmul(src_in[:, src, :, :], src_ex[:, src,
                                                                 0:3, :])
            ref_proj = torch.matmul(ref_in, ref_ex[:, 0:3, :])
            last = torch.tensor([[[0, 0, 0, 1.0]]]).repeat(len(src_in), 1,
                                                           1).cuda()
            src_proj = torch.cat((src_proj, last), 1)
            ref_proj = torch.cat((ref_proj, last), 1)

            proj = torch.matmul(src_proj, torch.inverse(ref_proj))
            rot = proj[:, :3, :3]
            trans = proj[:, :3, 3:4]

            y, x = torch.meshgrid([
                torch.arange(0,
                             height,
                             dtype=torch.float32,
                             device=ref_feature.device),
                torch.arange(0,
                             width,
                             dtype=torch.float32,
                             device=ref_feature.device)
            ])
            y, x = y.contiguous(), x.contiguous()
            y, x = y.view(height * width), x.view(height * width)
            xyz = torch.stack((x, y, torch.ones_like(x)))
            xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)
            rot_xyz = torch.matmul(rot, xyz)

            rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(
                1, 1, num_depth, 1) * depth_hypos.view(
                    batch, 1, num_depth, height * width)  # [B, 3, Ndepth, H*W]
            proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)
            proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]
            proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
            proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
            proj_xy = torch.stack((proj_x_normalized, proj_y_normalized),
                                  dim=3)
            grid = proj_xy

        warped_src_fea = F.grid_sample(src_feature[src][level],
                                       grid.view(batch, num_depth * height,
                                                 width, 2),
                                       mode='bilinear',
                                       padding_mode='zeros')
        warped_src_fea = warped_src_fea.view(batch, channels, num_depth,
                                             height, width)

        volume_sum = volume_sum + warped_src_fea
        volume_sq_sum = volume_sq_sum + warped_src_fea.pow(2)

    cost_volume = volume_sq_sum.div_(settings.nsrc + 1).sub_(
        volume_sum.div_(settings.nsrc + 1).pow_(2))

    if settings.mode == "test":
        del volume_sum
        del volume_sq_sum
        torch.cuda.empty_cache()

    return cost_volume
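
The final grid_sample call depends on pixel coordinates being normalized to [-1, 1]; that pattern on its own (toy sizes):

import torch
import torch.nn.functional as F

B, C, H, W = 1, 1, 4, 6
feat = torch.arange(float(H * W)).view(B, C, H, W)
y, x = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                      torch.arange(W, dtype=torch.float32))
grid = torch.stack((x / ((W - 1) / 2) - 1,                        # pixel -> [-1, 1]
                    y / ((H - 1) / 2) - 1), dim=-1).unsqueeze(0)  # (B, H, W, 2)
out = F.grid_sample(feat, grid, mode='bilinear', align_corners=True)
print(torch.allclose(out, feat))  # True: the identity grid reproduces the input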
Example n. 16

def calDepthHypo(netArgs, ref_depths, ref_intrinsics, src_intrinsics,
                 ref_extrinsics, src_extrinsics, depth_min, depth_max, level):
    ## Calculate depth hypothesis maps for refine steps

    nhypothesis_init = 48
    d = 4
    pixel_interval = 1

    nBatch = ref_depths.shape[0]
    height = ref_depths.shape[1]
    width = ref_depths.shape[2]

    if netArgs.mode == "train":

        depth_interval = torch.tensor([6.8085] * nBatch).cuda(
        )  # Hard code the interval for training on DTU with 1 level of refinement.
        depth_hypos = ref_depths.unsqueeze(1).repeat(1, d * 2, 1, 1)
        for depth_level in range(-d, d):
            depth_hypos[:, depth_level +
                        d, :, :] += (depth_level) * depth_interval[0]

        return depth_hypos

    with torch.no_grad():

        src_intrinsics = src_intrinsics.squeeze(1)
        src_extrinsics = src_extrinsics.squeeze(1)

        interval_maps = []
        depth_hypos = ref_depths.unsqueeze(1).repeat(1, d * 2, 1, 1)
        for batch in range(nBatch):
            xx, yy = torch.meshgrid(
                [torch.arange(0, width),
                 torch.arange(0, height)])

            xxx = xx.reshape([-1]).cuda().float()
            yyy = yy.reshape([-1]).cuda().float()

            X = torch.stack([xxx, yyy, torch.ones_like(xxx)], dim=0)

            D1 = torch.transpose(ref_depths[batch, :, :], 0, 1).reshape(
                [-1]
            )  # Transpose before reshape to produce identical results to numpy and matlab version.
            # D2 = D1+1
            X1 = X * D1

            D1 = D1 + 1
            X2 = X * D1

            ray1 = torch.matmul(torch.inverse(ref_intrinsics[batch]), X1)
            ray2 = torch.matmul(torch.inverse(ref_intrinsics[batch]), X2)

            X1 = torch.cat([ray1, X[-1, :].unsqueeze(0)], dim=0)
            X1 = torch.matmul(torch.inverse(ref_extrinsics[batch]), X1)
            X2 = torch.cat([ray2, X[-1, :].unsqueeze(0)], dim=0)
            X2 = torch.matmul(torch.inverse(ref_extrinsics[batch]), X2)

            X1 = torch.matmul(src_extrinsics[batch][0], X1)
            X2 = torch.matmul(src_extrinsics[batch][0], X2)

            X1 = X1[:3]
            X1 = torch.matmul(src_intrinsics[batch][0], X1)
            X1_d = X1[2].clone()
            X1 /= X1_d

            X2 = X2[:3]
            X2 = torch.matmul(src_intrinsics[batch][0], X2)
            # X2_d = X2[2].clone()
            X2 = X2 / X2[2]

            k = (X2[1] - X1[1]) / (X2[0] - X1[0])
            # b = X1[1]-k*X1[0]
            del X2

            theta = torch.atan(k)
            X3 = X1 + torch.stack([
                torch.cos(theta) * pixel_interval,
                torch.sin(theta) * pixel_interval,
                torch.zeros_like(X1[2, :])
            ],
                                  dim=0)

            A = torch.matmul(ref_intrinsics[batch],
                             ref_extrinsics[batch][:3, :3])
            tmp = torch.matmul(src_intrinsics[batch][0],
                               src_extrinsics[batch][0, :3, :3])
            A = torch.matmul(A, torch.inverse(tmp))

            tmp1 = X1_d * torch.matmul(A, X1)
            tmp2 = torch.matmul(A, X3)

            M1 = torch.cat(
                [X.t().unsqueeze(2), tmp2.t().unsqueeze(2)], axis=2)[:, 1:, :]
            M2 = tmp1.t()[:, 1:]
            ans = torch.matmul(torch.inverse(M1), M2.unsqueeze(2))
            delta_d = ans[:, 0, 0]

            interval_maps = torch.abs(delta_d).mean().repeat(
                ref_depths.shape[2], ref_depths.shape[1]).t()

            for depth_level in range(-d, d):
                depth_hypos[batch, depth_level +
                            d, :, :] += depth_level * interval_maps

        # print("Calculated:")
        # print(interval_maps[0,0])

        # pdb.set_trace()

        return depth_hypos  # Return the depth hypothesis map from statistical interval setting.
Example n. 17

def grid(W):
    # `dtype` is assumed to be a module-level tensor type (e.g. torch.FloatTensor)
    y, x = torch.meshgrid([torch.arange(0., W).type(dtype) / W] * 2)
    return torch.stack((x, y), dim=2).view(-1, 2)
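
For example, with dtype set to a float tensor type, grid(W) enumerates the W x W cell corners in [0, 1):

dtype = torch.float32  # assumption: the module-level dtype mentioned above
pts = grid(2)
print(pts)  # [[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]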

Example n. 18

    def _train_on_epoch(self, model, optimizer):
        model.train()

        data_loader = self.data_loaders[0]
        running_loss_XYZ = 0.0
        running_loss_mask = 0.0
        running_loss = 0.0

        for self.iteration, batch in enumerate(data_loader, self.iteration):
            input_images, depthGT, maskGT = utils.unpack_batch_fixed(batch, self.cfg.device)
            # ------ define ground truth------
            XGT, YGT = torch.meshgrid([
                torch.arange(self.cfg.outH), # [H,W]
                torch.arange(self.cfg.outW)]) # [H,W]
            XGT, YGT = XGT.float(), YGT.float()
            XYGT = torch.cat([
                XGT.repeat([self.cfg.outViewN, 1, 1]), 
                YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0) #[2V,H,W]
            XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device) # [1,2V,H,W] 

            with torch.set_grad_enabled(True):
                optimizer.zero_grad()

                XYZ, maskLogit = model(input_images)
                XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
                depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN * 3, :,  :]
                mask = (maskLogit > 0).bool()
                # ------ Compute loss ------
                loss_XYZ = self.l1(XY, XYGT)
                loss_XYZ += self.l1(depth.masked_select(mask),
                                    depthGT.masked_select(mask))
                loss_mask = self.sigmoid_bce(maskLogit, maskGT)
                loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

                # ------ Update weights ------
                loss.backward()
                # True Weight decay
                if self.cfg.trueWD is not None:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            # Decoupled weight decay: w <- w * (1 - lr * wd)
                            param.data.add_(param.data,
                                            alpha=-self.cfg.trueWD * group['lr'])
                optimizer.step()

            if self.on_after_batch is not None:
                if self.cfg.lrSched.lower() == "cyclical":
                    self.on_after_batch(self.iteration)
                else:
                    self.on_after_batch(self.epoch)

            running_loss_XYZ += loss_XYZ.item() * input_images.size(0)
            running_loss_mask += loss_mask.item() * input_images.size(0)
            running_loss += loss.item() * input_images.size(0)

        epoch_loss_XYZ = running_loss_XYZ / len(data_loader.dataset)
        epoch_loss_mask = running_loss_mask / len(data_loader.dataset)
        epoch_loss = running_loss / len(data_loader.dataset)

        print(f"\tTrain loss: {epoch_loss}")
        return {"epoch_loss_XYZ": epoch_loss_XYZ,
                "epoch_loss_mask": epoch_loss_mask,
                "epoch_loss": epoch_loss, }
Example n. 19

    def forward(self, x, y, loss_mask=None):

        # outputs, predictions = x
        # ground truths = y

        self.image_dimensions = x.shape
        _, _, h, w, d = self.image_dimensions  # h, w, d are known only at forward time, not in __init__
        coordinates_map = torch.meshgrid(
            [torch.arange(h),
             torch.arange(w),
             torch.arange(d)])
        self.coords = coordinates_map

        sour, targ, dy_gt, dx_gt, dd_gt = self.GT_sr

        self.compute_centroids(x)

        # self.compute_colors(data, x)   --> data must be accessed from here. Not done.

        sum_err = 0
        # for each prior relation, compute distance/color in the prediction and compare it with the ground truth.
        for rel in range(len(sour)):

            sour0 = int(sour[rel])
            targ0 = int(targ[rel])

            dy = self.centroids_y[sour0] - self.centroids_y[targ0]
            dx = self.centroids_x[sour0] - self.centroids_x[targ0]
            dd = self.centroids_d[sour0] - self.centroids_d[targ0]
            """print('rel is: ', rel, ' --------- source: ', sour0, '------- targ: ', targ0, '-----------')
            print('dy is:', dy)
            print('dx is:', dx)
            print('dd is:', dd)
            #dc = self.colors[:, i[rel]] - self.colors[:, j[rel]]                  
            print('dy_gt is:', dy_gt[rel])
            print('dx_gt is:', dx_gt[rel])
            print('dd_gt is:', dd_gt[rel])"""

            diff_y = (dy - dy_gt[rel]) / self.image_dimensions[3]
            diff_x = (dx - dx_gt[rel]) / self.image_dimensions[2]
            diff_d = (dd - dd_gt[rel]) / self.image_dimensions[4]
            #diff_c = (dc - dc_gt[rel])

            #print('diff_y is:', diff_y)
            #print('diff_x is:', diff_x)
            #print('diff_d is:', diff_d)

            dy_err = torch.mean(torch.square(weird_to_num(diff_y,
                                                          replace=0.0)))
            dx_err = torch.mean(torch.square(weird_to_num(diff_x,
                                                          replace=0.0)))
            dd_err = torch.mean(torch.square(weird_to_num(diff_d,
                                                          replace=0.0)))
            #dc_err = torch.mean(torch.square(weird_to_num(diff_c, replace=0.0)))

            #print('dy_err is:', dy_err)
            #print('dx_err is:', dx_err)
            #print('dd_err is:', dd_err)

            sum_err += dy_err + dx_err + dd_err  # + dc_err

        return sum_err
Example n. 20

    def graph_layer(self, encoded_seq, info, word_sec, section, positions):
        """
        Graph Layer -> Construct a document-level graph
        The graph edges hold representations for the connections between the nodes.
        Args:
            encoded_seq: Encoded sequence, shape (sentences, words, dimension)
            info:        (Tensor, 5 columns) entity_id, entity_type, start_wid, end_wid, sentence_id
            word_sec:    (Tensor) number of words per sentence
            section:     (Tensor <B, 3>) #entities/#mentions/#sentences per batch
            positions:   distances between nodes (only M-M and S-S)

        Returns: (Tensor) graph, (Tensor) tensor_mapping, (Tensors) indices, (Tensor) node information
        """
        # SENTENCE NODES
        sentences = torch.mean(encoded_seq,
                               dim=1)  # sentence nodes (avg of sentence words)

        # MENTION & ENTITY NODES
        temp_ = torch.arange(word_sec.max()).unsqueeze(0).repeat(
            sentences.size(0), 1).to(self.device)
        remove_pad = (temp_ < word_sec.unsqueeze(1))

        mentions = self.merge_tokens(info, encoded_seq,
                                     remove_pad)  # mention nodes
        entities = self.merge_mentions(info, mentions)  # entity nodes

        # all nodes in order: entities - mentions - sentences
        nodes = torch.cat((entities, mentions, sentences),
                          dim=0)  # e + m + s (all)
        nodes_info = self.node_info(
            section,
            info)  # info/node: node type | semantic type | sentence ID

        if self.types:  # + node types
            nodes = torch.cat((nodes, self.type_embed(nodes_info[:, 0])),
                              dim=1)

        # re-order nodes per document (batch)
        nodes = self.rearrange_nodes(nodes, section)
        nodes = self.split_n_pad(nodes, section, pad=0)

        nodes_info = self.rearrange_nodes(nodes_info, section)
        nodes_info = self.split_n_pad(nodes_info, section, pad=-1)

        # create initial edges (concat node representations)
        r_idx, c_idx = torch.meshgrid(
            torch.arange(nodes.size(1)).to(self.device),
            torch.arange(nodes.size(1)).to(self.device))
        graph = torch.cat((nodes[:, r_idx], nodes[:, c_idx]), dim=3)
        r_id, c_id = nodes_info[..., 0][:, r_idx], nodes_info[
            ..., 0][:, c_idx]  # node type indicators

        # pair masks
        pid = self.pair_ids(r_id, c_id)

        # Linear reduction layers
        reduced_graph = torch.where(
            pid['MS'].unsqueeze(-1), self.reduce['MS'](graph),
            torch.zeros(graph.size()[:-1] + (self.out_dim, )).to(self.device))
        reduced_graph = torch.where(pid['ME'].unsqueeze(-1),
                                    self.reduce['ME'](graph), reduced_graph)
        reduced_graph = torch.where(pid['ES'].unsqueeze(-1),
                                    self.reduce['ES'](graph), reduced_graph)

        if self.dist:
            dist_vec = self.dist_embed(positions)  # distances
            reduced_graph = torch.where(
                pid['SS'].unsqueeze(-1), self.reduce['SS'](torch.cat(
                    (graph, dist_vec), dim=3)), reduced_graph)
        else:
            reduced_graph = torch.where(pid['SS'].unsqueeze(-1),
                                        self.reduce['SS'](graph),
                                        reduced_graph)

        if self.context and self.dist:
            m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info,
                                    word_sec)
            m_cntx = self.prepare_mention_context(m_cntx, section, r_idx,
                                                  c_idx, encoded_seq[info[:,
                                                                          4]],
                                                  pid, nodes_info)

            reduced_graph = torch.where(
                pid['MM'].unsqueeze(-1), self.reduce['MM'](torch.cat(
                    (graph, dist_vec, m_cntx), dim=3)), reduced_graph)

        elif self.context:
            m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info,
                                    word_sec)
            m_cntx = self.prepare_mention_context(m_cntx, section, r_idx,
                                                  c_idx, encoded_seq[info[:,
                                                                          4]],
                                                  pid, nodes_info)

            reduced_graph = torch.where(
                pid['MM'].unsqueeze(-1), self.reduce['MM'](torch.cat(
                    (graph, m_cntx), dim=3)), reduced_graph)

        elif self.dist:
            reduced_graph = torch.where(
                pid['MM'].unsqueeze(-1), self.reduce['MM'](torch.cat(
                    (graph, dist_vec), dim=3)), reduced_graph)

        else:
            reduced_graph = torch.where(pid['MM'].unsqueeze(-1),
                                        self.reduce['MM'](graph),
                                        reduced_graph)

        if self.ee:
            reduced_graph = torch.where(pid['EE'].unsqueeze(-1),
                                        self.reduce['EE'](graph),
                                        reduced_graph)

        mask = self.get_nodes_mask(section.sum(dim=1))
        return reduced_graph, (r_idx, c_idx), nodes_info, mask
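
The r_idx/c_idx meshgrid above builds every ordered node pair in one indexing step; in isolation (toy node tensor):

import torch

nodes = torch.randn(1, 4, 8)  # (batch, n_nodes, dim)
n = nodes.size(1)
r_idx, c_idx = torch.meshgrid(torch.arange(n), torch.arange(n))
pairs = torch.cat((nodes[:, r_idx], nodes[:, c_idx]), dim=3)
print(pairs.shape)  # torch.Size([1, 4, 4, 16]): a 2*dim vector per node pair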
Example n. 21

def generate_region_meshgrid(num_regions, region_size, region_offsets):
    hh, ww = torch.meshgrid(torch.arange(0, num_regions[0]),
                            torch.arange(0, num_regions[1]))
    hh = (hh * region_size[0] + region_offsets[0]).cuda().float()
    ww = (ww * region_size[1] + region_offsets[1]).cuda().float()
    return hh, ww
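
A CPU sketch of the same computation (dropping the .cuda() calls) to show what the function returns, with small hypothetical sizes:

import torch

num_regions, region_size, region_offsets = (2, 3), (16, 16), (8, 8)
hh, ww = torch.meshgrid(torch.arange(0, num_regions[0]),
                        torch.arange(0, num_regions[1]))
hh = (hh * region_size[0] + region_offsets[0]).float()  # top-left row of each region
ww = (ww * region_size[1] + region_offsets[1]).float()  # top-left col of each region
print(hh)  # tensor([[ 8.,  8.,  8.], [24., 24., 24.]])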
Example n. 22

def eval(args, unet, imnet, eval_loader, epoch, global_step, device, logger,
         writer, optimizer, pde_layer):
    """Eval function. Used for evaluating entire slices and comparing to GT."""
    unet.eval()
    imnet.eval()
    phys_channels = ["p", "b", "u", "w"]
    phys2id = dict(zip(phys_channels, range(len(phys_channels))))
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)
    # only need the first batch
    data_tensors = next(iter(eval_loader))
    # send tensors to device
    data_tensors = [t.to(device) for t in data_tensors]
    hres_grid, lres_grid, _, _ = data_tensors
    latent_grid = unet(lres_grid)  # [batch, C, T, Z, X]
    nb, nc, nt, nz, nx = hres_grid.shape

    # permute such that C is the last channel for local implicit grid query
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]

    # define lambda function for pde_layer
    fwd_fn = lambda points: query_local_implicit_grid(imnet, latent_grid,
                                                      points, xmin, xmax)

    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)

    # layout query points for the desired slices
    eps = 1e-6
    t_seq = torch.linspace(eps, 1 - eps, nt)[::int(nt / 8)]  # temporal sequences (8 slices)
    z_seq = torch.linspace(eps, 1 - eps, nz)  # z sequences
    x_seq = torch.linspace(eps, 1 - eps, nx)  # x sequences

    query_coord = torch.stack(torch.meshgrid(t_seq, z_seq, x_seq),
                              axis=-1)  # [nt, nz, nx, 3]
    query_coord = query_coord.reshape([-1, 3]).to(device)  # [nt*nz*nx, 3]
    n_query = query_coord.shape[0]

    res_dict = defaultdict(list)

    n_iters = int(np.ceil(n_query / args.pseudo_batch_size))

    for idx in range(n_iters):
        sid = idx * args.pseudo_batch_size
        eid = min(sid + args.pseudo_batch_size, n_query)
        query_coord_batch = query_coord[sid:eid]
        query_coord_batch = query_coord_batch[None].expand(
            *(nb, eid - sid, 3))  # [nb, eid-sid, 3]

        pred_value, residue_dict = pde_layer(query_coord_batch,
                                             return_residue=True)
        pred_value = pred_value.detach()
        for key in residue_dict.keys():
            residue_dict[key] = residue_dict[key].detach()
        for name, chan_id in zip(phys_channels, range(4)):
            res_dict[name].append(pred_value[..., chan_id])  # [b, pb]
        for name, val in residue_dict.items():
            res_dict[name].append(val[..., 0])  # [b, pb]

    for key in res_dict.keys():
        res_dict[key] = torch.cat(res_dict[key], axis=1).reshape(
            [nb, len(t_seq), len(z_seq), len(x_seq)])

    # log the imgs sample-by-sample
    for samp_id in range(nb):
        for key in res_dict.keys():
            field = res_dict[key][samp_id]  # [nt, nz, nx]
            # add predicted slices
            images = utils.batch_colorize_scalar_tensors(
                field)  # [nt, nz, nx, 3]

            writer.add_images('sample_{}/{}/predicted'.format(samp_id, key),
                              images,
                              dataformats='NHWC',
                              global_step=int(global_step))
            # add ground truth slices (only for phys channels)
            if key in phys_channels:
                gt_fields = hres_grid[samp_id, phys2id[key],
                                      ::int(nt / 8)]  # [nt, nz, nx]
                gt_images = utils.batch_colorize_scalar_tensors(
                    gt_fields)  # [nt, nz, nx, 3]

                writer.add_images(
                    'sample_{}/{}/ground_truth'.format(samp_id, key),
                    gt_images,
                    dataformats='NHWC',
                    global_step=int(global_step))
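The chunked evaluation loop above is worth isolating: it bounds memory by querying the implicit grid one pseudo-batch at a time and concatenating the results. A self-contained sketch with a stand-in function (`fn` here is hypothetical):

import torch

n_query, pseudo_batch_size, nb = 1000, 256, 2
query_coord = torch.rand(n_query, 3)
fn = lambda pts: pts.sum(-1)  # stand-in for the pde_layer call

chunks = []
n_iters = -(-n_query // pseudo_batch_size)  # ceiling division
for idx in range(n_iters):
    sid = idx * pseudo_batch_size
    eid = min(sid + pseudo_batch_size, n_query)
    batch = query_coord[sid:eid][None].expand(nb, eid - sid, 3)
    chunks.append(fn(batch))
out = torch.cat(chunks, dim=1)  # [nb, n_query]
print(out.shape)  # torch.Size([2, 1000])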
Esempio n. 23
0
def render_singletask_gp(
    ax: Union[plt.Axes, Axes3D, Sequence[plt.Axes]],
    data_x: to.Tensor,
    data_y: to.Tensor,
    idcs_sel: list,
    data_x_min: to.Tensor = None,
    data_x_max: to.Tensor = None,
    x_label: str = '',
    y_label: str = '',
    z_label: str = '',
    min_gp_obsnoise: float = None,
    resolution: int = 201,
    num_stds: int = 2,
    alpha: float = 0.3,
    color: str = None,
    curve_label: str = 'mean',
    heatmap_cmap: colors.Colormap = None,
    show_legend_posterior: bool = True,
    show_legend_std: bool = False,
    show_legend_data: bool = True,
    legend_data_cmap: colors.Colormap = None,
    colorbar_label: str = None,
    title: str = None,
    render3D: bool = True,
) -> plt.Figure:
    """
    Fit the GP posterior to the input data and plot the mean and std as well as the data points.
    There are 3 options, inferred from the data dimensions: 1D plot, 3D surface plot, and 2D heat map.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param ax: axis of the figure to plot on; in case of a 2-dim heat map plot, provide 4 axes (2 heat maps, 2 color bars)
    :param data_x: data to plot on the x-axis
    :param data_y: data to process and plot on the y-axis
    :param idcs_sel: selected indices of the input data
    :param data_x_min: explicit minimum value for the evaluation grid, by default this value is extracted from `data_x`
    :param data_x_max: explicit maximum value for the evaluation grid, by default this value is extracted from `data_x`
    :param x_label: label for x-axis
    :param y_label: label for y-axis
    :param z_label: label for z-axis (3D plot only)
    :param min_gp_obsnoise: set a minimal noise value (normalized) for the GP, if `None` the GP has no measurement noise
    :param resolution: number of samples for the input (corresponds to x-axis resolution of the plot)
    :param num_stds: number of standard deviations to plot around the mean
    :param alpha: transparency (alpha-value) for the std area
    :param color: color (e.g. 'k' for black), `None` invokes the default behavior
    :param curve_label: label for the mean curve (1D plot only)
    :param heatmap_cmap: color map forwarded to `render_heatmap()` (2D plot only), `None` to use Pyrado's default
    :param show_legend_posterior: flag if the legend entry for the posterior should be printed (affects mean and std)
    :param show_legend_std: flag if a legend entry for the std area should be printed
    :param show_legend_data: flag if a legend entry for the individual data points should be printed
    :param legend_data_cmap: color map for the sampled points, default is 'binary'
    :param colorbar_label: label for the color bar (2D plot only)
    :param title: title displayed above the figure, set to `None` to suppress the title
    :param render3D: use 3D rendering if possible
    :return: handle to the resulting figure
    """
    if data_x.ndim != 2:
        raise pyrado.ShapeErr(
            msg="The GP's input data needs to be of shape num_samples x dim_input!")
    data_x = data_x[:, idcs_sel]  # forget the rest
    dim_x = data_x.shape[1]  # samples are along axis 0

    if data_y.ndim != 2:
        raise pyrado.ShapeErr(given=data_y,
                              expected_match=to.Size([data_x.shape[0], 1]))

    if legend_data_cmap is None:
        legend_data_cmap = plt.get_cmap('binary')

    # Project to normalized input and standardized output
    if data_x_min is None or data_x_max is None:
        data_x_min, data_x_max = to.min(data_x, dim=0)[0], to.max(data_x,
                                                                  dim=0)[0]
    data_y_mean, data_y_std = to.mean(data_y, dim=0), to.std(data_y, dim=0)
    data_x = (data_x - data_x_min) / (data_x_max - data_x_min)
    data_y = (data_y - data_y_mean) / data_y_std

    # Create and fit the GP model
    gp = SingleTaskGP(data_x, data_y)
    if min_gp_obsnoise is not None:
        gp.likelihood.noise_covar.register_constraint(
            'raw_noise', GreaterThan(min_gp_obsnoise))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_model(mll)
    print_cbt('Fitted the SingleTaskGP.', 'g')

    argmax_pmean_norm, argmax_pmean_val_stdzed = optimize_acqf(
        acq_function=PosteriorMean(gp),
        bounds=to.stack([to.zeros(dim_x), to.ones(dim_x)]),
        q=1,
        num_restarts=500,
        raw_samples=1000)
    # Project back
    argmax_posterior = argmax_pmean_norm * (data_x_max -
                                            data_x_min) + data_x_min
    argmax_pmean_val = argmax_pmean_val_stdzed * data_y_std + data_y_mean
    print_cbt(
        f'Converged to argmax of the posterior mean: {argmax_posterior.numpy()}',
        'g')

    mll.eval()
    gp.eval()

    if dim_x == 1:
        # Evaluation grid
        x_grid = np.linspace(min(data_x),
                             max(data_x),
                             resolution,
                             endpoint=True).flatten()
        x_grid = to.from_numpy(x_grid)

        # Mean and standard deviation of the surrogate model
        posterior = gp.posterior(x_grid)
        mean = posterior.mean.detach().flatten()
        std = to.sqrt(posterior.variance.detach()).flatten()

        # Project back from normalized input and standardized output
        x_grid = x_grid * (data_x_max - data_x_min) + data_x_min
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean = mean * data_y_std + data_y_mean
        std *= data_y_std  # double-checked with posterior.mvn.confidence_region()

        # Plot the curve
        ax.fill_between(x_grid.numpy(),  # draw on the given axis, not the current one
                         mean.numpy() - num_stds * std.numpy(),
                         mean.numpy() + num_stds * std.numpy(),
                         alpha=alpha,
                         color=color)
        ax.plot(x_grid.numpy(), mean.numpy(), color=color)

        # Plot the queried data points
        scat_plot = ax.scatter(data_x.numpy().flatten(),
                               data_y.numpy().flatten(),
                               marker='o',
                               c=np.arange(data_x.shape[0], dtype=int),  # np.int was removed in NumPy 1.24
                               cmap=legend_data_cmap)

        if show_legend_data:
            scat_legend = ax.legend(
                *scat_plot.legend_elements(fmt='{x:.0f}'),  # integer formatter
                bbox_to_anchor=(0., 1.1, 1., -0.1),
                title='query points',
                ncol=data_x.shape[0],
                loc='upper center',
                mode='expand',
                borderaxespad=0.,
                handletextpad=-0.5)
            ax.add_artist(scat_legend)
            # Increase vertical space between subplots when printing the data labels
            # plt.tight_layout(pad=2.)  # ignore argument
            # plt.subplots_adjust(hspace=0.6)

        # Plot the argmax of the posterior mean
        # ax.scatter(argmax_posterior.item(), argmax_pmean_val, c='darkorange', marker='o', s=60, label='argmax')
        ax.axvline(argmax_posterior.item(),
                   c='darkorange',
                   lw=1.5,
                   label='argmax')

        if show_legend_posterior:
            ax.add_artist(ax.legend(loc='lower right'))

    elif dim_x == 2:
        # Create mesh grid matrices from x and y vectors
        # x0_grid = to.linspace(min(data_x[:, 0]), max(data_x[:, 0]), resolution)
        # x1_grid = to.linspace(min(data_x[:, 1]), max(data_x[:, 1]), resolution)
        x0_grid = to.linspace(0, 1, resolution)
        x1_grid = to.linspace(0, 1, resolution)
        x0_mesh, x1_mesh = to.meshgrid([x0_grid, x1_grid])
        # transpose is not necessary but yields a mesh identical to np.meshgrid
        x0_mesh, x1_mesh = x0_mesh.t(), x1_mesh.t()

        # Mean and standard deviation of the surrogate model
        x_test = to.stack([
            x0_mesh.reshape(resolution**2, 1),
            x1_mesh.reshape(resolution**2, 1)
        ], -1).squeeze(1)
        posterior = gp.posterior(x_test)  # identical to gp.likelihood(gp(x_test))
        mean = posterior.mean.detach().reshape(resolution, resolution)
        std = to.sqrt(posterior.variance.detach()).reshape(
            resolution, resolution)

        # Project back from normalized input and standardized output
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean_raw = mean * data_y_std + data_y_mean
        std_raw = std * data_y_std

        if render3D:
            # Project back from normalized input and standardized output (custom for 3D)
            x0_mesh = x0_mesh * (data_x_max[0] - data_x_min[0]) + data_x_min[0]
            x1_mesh = x1_mesh * (data_x_max[1] - data_x_min[1]) + data_x_min[1]
            lower = mean_raw - num_stds * std_raw
            upper = mean_raw + num_stds * std_raw

            # Plot a 2D surface in 3D
            ax.plot_surface(x0_mesh.numpy(), x1_mesh.numpy(), mean_raw.numpy())
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            lower.numpy(),
                            color='r',
                            alpha=alpha)
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            upper.numpy(),
                            color='r',
                            alpha=alpha)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_zlabel(z_label)

            # Plot the queried data points
            scat_plot = ax.scatter(data_x[:, 0].numpy(),
                                   data_x[:, 1].numpy(),
                                   data_y.numpy(),
                                   marker='o',
                                   c=np.arange(data_x.shape[0], dtype=int),
                                   cmap=legend_data_cmap)

            if show_legend_data:
                scat_legend = ax.legend(
                    *scat_plot.legend_elements(
                        fmt='{x:.0f}'),  # integer formatter
                    bbox_to_anchor=(0.05, 1.1, 0.95, -0.1),
                    loc='upper center',
                    ncol=data_x.shape[0],
                    mode='expand',
                    borderaxespad=0.,
                    handletextpad=-0.5)
                ax.add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            x, y = argmax_posterior[0, 0], argmax_posterior[0, 1]
            ax.scatter(x,
                       y,
                       argmax_pmean_val,
                       c='darkorange',
                       marker='*',
                       s=60)
            # ax.plot((x, x), (y, y), (data_y.min(), data_y.max()), c='k', ls='--', lw=1.5)

        else:
            if len(ax) != 4:
                raise pyrado.ShapeErr(
                    msg='Provide 4 axes! 2 heat maps and 2 color bars.')

            # Project back normalized input and standardized output (custom for 2D)
            x0_grid_raw = x0_grid * (data_x_max[0] -
                                     data_x_min[0]) + data_x_min[0]
            x1_grid_raw = x1_grid * (data_x_max[1] -
                                     data_x_min[1]) + data_x_min[1]

            # Plot a 2D image
            df_mean = pd.DataFrame(mean_raw.numpy(),
                                   columns=x0_grid_raw.numpy(),
                                   index=x1_grid_raw.numpy())
            render_heatmap(df_mean,
                           ax_hm=ax[0],
                           ax_cb=ax[1],
                           x_label=x_label,
                           y_label=y_label,
                           annotate=False,
                           fig_canvas_title='Returns',
                           tick_label_prec=2,
                           add_sep_colorbar=True,
                           cmap=heatmap_cmap,
                           colorbar_label=colorbar_label,
                           num_major_ticks_hm=3,
                           num_major_ticks_cb=2,
                           colorbar_orientation='horizontal')

            df_std = pd.DataFrame(std_raw.numpy(),
                                  columns=x0_grid_raw.numpy(),
                                  index=x1_grid_raw.numpy())
            render_heatmap(
                df_std,
                ax_hm=ax[2],
                ax_cb=ax[3],
                x_label=x_label,
                y_label=y_label,
                annotate=False,
                fig_canvas_title='Standard Deviations',
                tick_label_prec=2,
                add_sep_colorbar=True,
                cmap=heatmap_cmap,
                colorbar_label=colorbar_label,
                num_major_ticks_hm=3,
                num_major_ticks_cb=2,
                colorbar_orientation='horizontal',
                norm=colors.Normalize())  # explicitly instantiate a new norm

            # Plot the queried data points
            for i in [0, 2]:
                scat_plot = ax[i].scatter(data_x[:, 0].numpy(),
                                          data_x[:, 1].numpy(),
                                          marker='o',
                                          s=15,
                                          c=np.arange(data_x.shape[0],
                                                      dtype=int),
                                          cmap=legend_data_cmap)

                if show_legend_data:
                    scat_legend = ax[i].legend(
                        *scat_plot.legend_elements(
                            fmt='{x:.0f}'),  # integer formatter
                        bbox_to_anchor=(0., 1.1, 1., 0.05),
                        loc='upper center',
                        ncol=data_x.shape[0],
                        mode='expand',
                        borderaxespad=0.,
                        handletextpad=-0.5)
                    ax[i].add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            ax[0].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            ax[2].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            # ax[0].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[0].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)
            # ax[2].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[2].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)

    else:
        raise pyrado.ValueErr(msg='Can only plot 1-dim or 2-dim data!')

    return plt.gcf()
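A minimal usage sketch, assuming the function above and its pyrado/BoTorch dependencies are importable; the toy data is illustrative only:

import torch as to
import matplotlib.pyplot as plt

data_x = to.rand(10, 1)                               # 10 samples, 1-dim input
data_y = to.sin(3 * data_x) + 0.1 * to.randn(10, 1)   # noisy targets, shape (10, 1)
fig, ax = plt.subplots(1, 1, tight_layout=True)
render_singletask_gp(ax, data_x, data_y, idcs_sel=[0],
                     x_label='x', y_label='y')
plt.show()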
Example n. 24
    def forward(self, feature_pyramid1, feature_pyramid2, actions=None):
        """Run the model."""
        context = None
        flow = None
        flow_up = None
        context_up = None
        flows = []

        # make sure that actions are provided iff network is configured for action use
        if actions is not None:
            assert self._action_channels == actions.shape[1]
        else:
            assert self._action_channels is None

        # Go top down through the levels to the second to last one to estimate flow.
        for level, (features1, features2) in reversed(
                list(enumerate(zip(feature_pyramid1, feature_pyramid2)))[1:]):

            # init flows with zeros for coarsest level if needed
            if self._shared_flow_decoder and flow_up is None:
                # note: .shape.as_list() is a TensorFlow idiom; PyTorch shapes unpack directly
                batch_size, height, width, _ = features1.shape
                flow_up = torch.zeros([batch_size, height, width,
                                       2]).to(gpu_utils.device)
                if self._num_context_up_channels:
                    num_channels = int(self._num_context_up_channels *
                                       self._channel_multiplier)
                    context_up = torch.zeros(
                        [batch_size, height, width,
                         num_channels]).to(gpu_utils.device)

            # Warp features2 with upsampled flow from higher level.
            if flow_up is None or not self._use_feature_warp:
                warped2 = features2
            else:
                warp_up = uflow_utils.flow_to_warp(flow_up)
                warped2 = uflow_utils.resample(features2, warp_up)

            # Compute cost volume by comparing features1 and warped features2.
            features1_normalized, warped2_normalized = uflow_utils.normalize_features(
                [features1, warped2],
                normalize=self._normalize_before_cost_volume,
                center=self._normalize_before_cost_volume,
                moments_across_channels=True,
                moments_across_images=True)

            if self._use_cost_volume:
                cost_volume = uflow_utils.compute_cost_volume(
                    features1_normalized,
                    warped2_normalized,
                    max_displacement=4)
            else:
                concat_features = torch.cat(
                    [features1_normalized, warped2_normalized], dim=1)
                cost_volume = self._cost_volume_surrogate_convs[level](
                    concat_features)

            cost_volume = func.leaky_relu(
                cost_volume, negative_slope=self._leaky_relu_alpha)

            if self._shared_flow_decoder:
                # This will ensure to work for arbitrary feature sizes per level.
                conv_1x1 = self._1x1_shared_decoder[level]
                features1 = conv_1x1(features1)

            # Compute context and flow from previous flow, cost volume, and features1.
            if flow_up is None:
                x_in = torch.cat([cost_volume, features1], dim=1)
            else:
                if context_up is None:
                    x_in = torch.cat([flow_up, cost_volume, features1], dim=1)
                else:
                    x_in = torch.cat(
                        [context_up, flow_up, cost_volume, features1], dim=1)

            if self._action_channels is not None and self._action_channels > 0:
                # convert every entry in actions to a channel filled with this value and attach it to flow input
                B, _, H, W = features1.shape
                # additionally append xy position augmentation
                action_tensor = actions[:, :, None, None].repeat(1, 1, H, W)
                gy, gx = torch.meshgrid(
                    [torch.arange(H).float(),
                     torch.arange(W).float()])
                gx = gx.repeat(B, 1, 1, 1).to(gpu_utils.device)
                gy = gy.repeat(B, 1, 1, 1).to(gpu_utils.device)
                x_in = torch.cat([x_in, action_tensor, gx, gy], dim=1)

            # Use dense-net connections.
            x_out = None
            if self._shared_flow_decoder:
                # reuse the same flow decoder on all levels
                flow_layers = self._flow_layers
            else:
                flow_layers = self._flow_layers[level]
            for layer in flow_layers[:-1]:
                x_out = layer(x_in)
                x_in = torch.cat([x_in, x_out], dim=1)
            context = x_out

            flow = flow_layers[-1](context)

            # dropout full layer
            if self.training and self._drop_out_rate:
                maybe_dropout = (torch.rand([]) > self._drop_out_rate).type(
                    torch.get_default_dtype())
                # note that operation must not be inplace, otherwise autograd will fail pathetically
                context = context * maybe_dropout
                flow = flow * maybe_dropout

            if flow_up is not None and self._accumulate_flow:
                flow += flow_up

            # Upsample flow for the next lower level.
            flow_up = uflow_utils.upsample(flow, is_flow=True)
            if self._num_context_up_channels:
                context_up = self._context_up_layers[level](context)

            # Append results to list.
            flows.insert(0, flow)

        # Refine flow at level 1.
        # refinement = self._refine_model(torch.cat([context, flow], dim=1))
        refinement = torch.cat([context, flow], dim=1)
        for layer in self._refine_model:
            refinement = layer(refinement)

        # dropout refinement
        if self.training and self._drop_out_rate:
            maybe_dropout = (torch.rand([]) > self._drop_out_rate).type(
                torch.get_default_dtype())
            # note that operation must not be inplace, otherwise autograd will fail pathetically
            refinement = refinement * maybe_dropout

        refined_flow = flow + refinement
        flows[0] = refined_flow

        flows.insert(0, uflow_utils.upsample(flows[0], is_flow=True))
        flows.insert(0, uflow_utils.upsample(flows[0], is_flow=True))

        return flows
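The xy position augmentation in the decoder above, as a standalone helper (a sketch; `append_xy_channels` is not part of the original code):

import torch

def append_xy_channels(feats: torch.Tensor) -> torch.Tensor:
    # Append two channels holding each pixel's x and y index to a (B, C, H, W) map.
    B, _, H, W = feats.shape
    gy, gx = torch.meshgrid([torch.arange(H).float(), torch.arange(W).float()])
    gx = gx.expand(B, 1, H, W).to(feats.device)
    gy = gy.expand(B, 1, H, W).to(feats.device)
    return torch.cat([feats, gx, gy], dim=1)

x = torch.randn(2, 8, 4, 5)
print(append_xy_channels(x).shape)  # torch.Size([2, 10, 4, 5])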
Example n. 25
def coords_grid(batch, ht, wd):
    # Pixel coordinate grid with channel order (x, y), shape [batch, 2, ht, wd].
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
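A quick check of the helper: channel 0 varies along the width (x) and channel 1 along the height (y), because of the `[::-1]` reversal:

grid = coords_grid(1, 2, 3)
print(grid.shape)   # torch.Size([1, 2, 2, 3])
print(grid[0, 0])   # x: tensor([[0., 1., 2.], [0., 1., 2.]])
print(grid[0, 1])   # y: tensor([[0., 0., 0.], [1., 1., 1.]])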
Example n. 26
def main(args):
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[2, 0])
    ax5 = fig.add_subplot(gs[3, 0])
    ax6 = fig.add_subplot(gs[1, 1])
    ax7 = fig.add_subplot(gs[2, 1])
    ax8 = fig.add_subplot(gs[3, 1])
    xlim = tuple(int(x) for x in args.x_lim.strip().split(','))
    ylim = tuple(int(x) for x in args.y_lim.strip().split(','))
    assert len(xlim) == 2
    assert len(ylim) == 2

    # 1. Plot samples drawn from BananaShaped distribution
    x1, x2 = torch.meshgrid(
        [torch.linspace(*xlim, 100),
         torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    ax1.contourf(
        x1,
        x2,
        p,
        cmap='OrRd',
    )
    ax1.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    ax2.contourf(x1, x2, p, cmap='OrRd')
    ax2.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(vanilla HMC)')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax3.contourf(x1, x2, p, cmap='OrRd')
    ax3.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(DiagNormal autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3)

    # 3(b). Draw samples using NeuTra HMC
    logging.info(
        '\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4)
    ax4.set(xlabel='x0',
            ylabel='x1',
            title='Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax5.contourf(x1, x2, p, cmap='OrRd')
    ax5.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior (transformed) \n(DiagNormal + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5)

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(
        model, partial(iterated, args.num_flows, block_autoregressive))
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax6.contourf(x1, x2, p, cmap='OrRd')
    ax6.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(BNAF autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6)

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7)
    ax7.set(xlabel='x0',
            ylabel='x1',
            title='Posterior (warped) samples \n(BNAF + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax8.contourf(x1, x2, p, cmap='OrRd')
    ax8.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior (transformed) \n(BNAF + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
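Every contour panel above follows the same meshgrid-then-log_prob pattern; a minimal standalone version with a stock 2-D Gaussian in place of BananaShaped:

import torch
import matplotlib.pyplot as plt
from torch.distributions import MultivariateNormal

x1, x2 = torch.meshgrid(
    [torch.linspace(-3, 3, 100),
     torch.linspace(-3, 3, 100)])
d = MultivariateNormal(torch.zeros(2), torch.eye(2))
p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))  # (100, 100) density
plt.contourf(x1, x2, p, cmap='OrRd')
plt.show()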
Example n. 27
def construct_image(c):
    x, y = torch.meshgrid(torch.linspace(-5, 5, 64), -torch.linspace(-5, 5, 64))
    b = torch.stack([x ** i * y ** j for i in range(5) for j in range(5 - i)])
    return torch.einsum('ij,jkl->ikl', c, b).view(-1, 64, 64)
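The basis `b` stacks the 15 monomials x**i * y**j with i + j < 5, so `c` needs 15 columns; each row of `c` yields one 64x64 image:

c = torch.randn(2, 15)
imgs = construct_image(c)
print(imgs.shape)  # torch.Size([2, 64, 64])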
Example n. 28
def get_grid(size):
    grids = torch.meshgrid([torch.arange(s) for s in size])
    return torch.stack(grids, 0).float()
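This generalizes meshgrid to any number of dimensions: the result has shape (len(size), *size), with channel d holding the index along dimension d:

g = get_grid((2, 3))
print(g.shape)  # torch.Size([2, 2, 3])
print(g[0])     # tensor([[0., 0., 0.], [1., 1., 1.]])
print(g[1])     # tensor([[0., 1., 2.], [0., 1., 2.]])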
Example n. 29
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

device = [0, 1]

from model import CreateModel
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from tqdm import tqdm

torch.manual_seed(0)

# Data
tick = torch.arange(0., 256.) / 255.  #0~255
R, G, B = torch.meshgrid(tick, tick, tick)
# gamma 1.1
X = torch.stack((torch.reshape(R, (-1, 1)),
                 torch.reshape(G, (-1, 1)),
                 torch.reshape(B, (-1, 1))), dim=1)
X = torch.unsqueeze(X, 2)
Y = torch.pow(X, 1 / 1.1)

X -= 0.5
Y -= 0.5

dataset = Data.TensorDataset(X, Y)
data_iter = Data.DataLoader(dataset)  # the source snippet is truncated here; loader options omitted
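The same dataset construction at a toy resolution (4 levels per channel instead of 256), so the shapes are easy to inspect:

import torch

tick = torch.arange(0., 4.) / 3.
R, G, B = torch.meshgrid(tick, tick, tick)
X = torch.stack((R.reshape(-1, 1), G.reshape(-1, 1), B.reshape(-1, 1)), dim=1)
X = torch.unsqueeze(X, 2)   # [64, 3, 1, 1]: all 4**3 RGB triplets
Y = torch.pow(X, 1 / 1.1)   # gamma 1.1 target
print(X.shape, Y.shape)     # torch.Size([64, 3, 1, 1]) twice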
Example n. 30
def from_matrix( ranges_i, ranges_j, keep ) :
    r"""Turns a boolean matrix into a KeOps-friendly **ranges** argument.

    This routine is a helper for the **block-sparse** reduction mode of KeOps,
    allowing you to turn clustering information (**ranges_i**,
    **ranges_j**) and a cluster-to-cluster boolean mask (**keep**) 
    into integer tensors of indices that can be used to schedule the KeOps routines.

    Suppose that you're working with variables :math:`x_i`  (:math:`i \in [0,10^6)`),
    :math:`y_j`  (:math:`j \in [0,10^7)`), and that you want to compute a KeOps reduction
    over indices :math:`i` or :math:`j`: Instead of performing the full 
    kernel dot product (:math:`10^6 \cdot 10^7 = 10^{13}` operations!), 
    you may want to restrict yourself to
    interactions between points :math:`x_i` and :math:`y_j` that are "close" to each other.

    With KeOps, the simplest way of doing so is to:
    
    1. Compute cluster labels for the :math:`x_i`'s and :math:`y_j`'s, using e.g. 
       the :func:`grid_cluster` method.
    2. Compute the ranges (**ranges_i**, **ranges_j**) and centroids associated 
       to each cluster, using e.g. the :func:`cluster_ranges_centroids` method.
    3. Sort the tensors ``x_i`` and ``y_j`` with :func:`sort_clusters` to make sure that the
       clusters are stored contiguously in memory (this step is **critical** for performance on GPUs).

    At this point:
        - the :math:`k`-th cluster of :math:`x_i`'s is given by ``x_i[ ranges_i[k,0]:ranges_i[k,1], : ]``, for :math:`k \in [0,M)`, 
        - the :math:`\ell`-th cluster of :math:`y_j`'s is given by ``y_j[ ranges_j[l,0]:ranges_j[l,1], : ]``, for :math:`\ell \in [0,N)`.
    
    4. Compute the :math:`(M,N)` matrix **dist** of pairwise distances between cluster centroids.
    5. Apply a threshold on **dist** to generate a boolean matrix ``keep = dist < threshold``.
    6. Define a KeOps reduction ``my_genred = Genred(..., axis = 0 or 1)``, as usual.
    7. Compute the block-sparse reduction through
       ``result = my_genred(x_i, y_j, ranges = from_matrix(ranges_i,ranges_j,keep) )``

    :func:`from_matrix` is thus the routine that turns a **high-level description**
    of your block-sparse computation (cluster ranges + boolean matrix)
    into a set of **integer tensors** (the **ranges** optional argument), 
    used by KeOps to schedule computations on the GPU.

    Args:
        ranges_i ((M,2) IntTensor): List of :math:`[\text{start}_k,\text{end}_k)` indices.
            For :math:`k \in [0,M)`, the :math:`k`-th cluster of ":math:`i`" variables is
            given by ``x_i[ ranges_i[k,0]:ranges_i[k,1], : ]``, etc.
        ranges_j ((N,2) IntTensor): List of :math:`[\text{start}_\ell,\text{end}_\ell)` indices.
            For :math:`\ell \in [0,N)`, the :math:`\ell`-th cluster of ":math:`j`" variables is
            given by ``y_j[ ranges_j[l,0]:ranges_j[l,1], : ]``, etc.
        keep ((M,N) BoolTensor): 
            If the output ``ranges`` of :func:`from_matrix` is used in a KeOps reduction,
            we will only compute and reduce the terms associated to pairs of "points"
            :math:`x_i`, :math:`y_j` in clusters :math:`k` and :math:`\ell`
            if ``keep[k,l] == 1``.

    Returns:
        A 6-tuple of IntTensors that can be used as an optional **ranges**
        argument of :class:`torch.Genred <pykeops.torch.Genred>`. See the documentation of :class:`torch.Genred <pykeops.torch.Genred>` for reference.

    Example:
        >>> r_i = torch.IntTensor( [ [2,5], [7,12] ] )          # 2 clusters: X[0] = x_i[2:5], X[1] = x_i[7:12]
        >>> r_j = torch.IntTensor( [ [1,4], [4,9], [20,30] ] )  # 3 clusters: Y[0] = y_j[1:4], Y[1] = y_j[4:9], Y[2] = y_j[20:30]
        >>> x,y = torch.Tensor([1., 0.]), torch.Tensor([1.5, .5, 2.5])  # dummy "centroids"
        >>> dist = (x[:,None] - y[None,:])**2
        >>> keep = (dist <= 1)                                  # (2,3) matrix
        >>> print(keep)
        tensor([[1, 1, 0],
                [0, 1, 0]], dtype=torch.uint8)
        --> X[0] interacts with Y[0] and Y[1], X[1] interacts with Y[1]
        >>> (ranges_i,slices_i,redranges_j, ranges_j,slices_j,redranges_i) = from_matrix(r_i,r_j,keep)
        --> (ranges_i,slices_i,redranges_j) will be used for reductions with respect to "j" (axis=1)
        --> (ranges_j,slices_j,redranges_i) will be used for reductions with respect to "i" (axis=0)

        Information relevant if **axis** = 1:

        >>> print(ranges_i)  # = r_i
        tensor([[ 2,  5],
                [ 7, 12]], dtype=torch.int32)
        --> Two "target" clusters in a reduction wrt. j
        >>> print(slices_i)  
        tensor([2, 3], dtype=torch.int32)
        --> X[0] is associated to redranges_j[0:2]
        --> X[1] is associated to redranges_j[2:3]
        >>> print(redranges_j)
        tensor([[1, 4],
                [4, 9],
                [4, 9]], dtype=torch.int32)
        --> For X[0], i in [2,3,4],       we'll reduce over j in [1,2,3] and [4,5,6,7,8]
        --> For X[1], i in [7,8,9,10,11], we'll reduce over j in [4,5,6,7,8]


        Information relevant if **axis** = 0:

        >>> print(ranges_j)
        tensor([[ 1,  4],
                [ 4,  9],
                [20, 30]], dtype=torch.int32)
        --> Three "target" clusters in a reduction wrt. i
        >>> print(slices_j)
        tensor([1, 3, 3], dtype=torch.int32)
        --> Y[0] is associated to redranges_i[0:1]
        --> Y[1] is associated to redranges_i[1:3]
        --> Y[2] is associated to redranges_i[3:3] = no one...
        >>> print(redranges_i)
        tensor([[ 2,  5],
                [ 2,  5],
                [ 7, 12]], dtype=torch.int32)
        --> For Y[0], j in [1,2,3],     we'll reduce over i in [2,3,4]
        --> For Y[1], j in [4,5,6,7,8], we'll reduce over i in [2,3,4] and [7,8,9,10,11]
        --> For Y[2], j in [20,21,...,29], there is no reduction to be done
    """
    I, J = torch.meshgrid( (torch.arange(0, keep.shape[0]), torch.arange(0,keep.shape[1])) )
    redranges_i = ranges_i[ I.t()[keep.t()] ]  # Use PyTorch indexing to "stack" copies of ranges_i[...]
    redranges_j = ranges_j[ J[keep] ]
    slices_i = keep.sum(1).cumsum(0).int()  # slice indices in the "stacked" array redranges_j
    slices_j = keep.sum(0).cumsum(0).int()  # slice indices in the "stacked" array redranges_i
    return (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i)
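The docstring's toy example as a runnable sketch (assumes `from_matrix` from above, or `from pykeops.torch.cluster import from_matrix`):

import torch

r_i = torch.IntTensor([[2, 5], [7, 12]])          # 2 clusters of i-variables
r_j = torch.IntTensor([[1, 4], [4, 9], [20, 30]])  # 3 clusters of j-variables
x, y = torch.Tensor([1., 0.]), torch.Tensor([1.5, .5, 2.5])  # dummy centroids
keep = (x[:, None] - y[None, :]) ** 2 <= 1         # (2, 3) boolean mask
ranges = from_matrix(r_i, r_j, keep)
print(ranges[0])  # ranges_i, exactly as printed in the docstring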
Example n. 31
def make_material_atlas(image: torch.Tensor, faces_verts_uvs: torch.Tensor,
                        texture_size: int) -> torch.Tensor:
    r"""
    Given a single texture image and the uv coordinates for all the
    face vertices, create a square texture map per face using
    the formulation from [1].

    For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate system
    with the x axis being the vector (v1 - v0) and the y axis being the vector (v2 - v0).
    The barycentric coordinates range from [0, 1] in the +x and +y direction so this creates
    a triangular texture space with vertices at (0, 1), (0, 0) and (1, 0).

    The per face texture map is of shape (texture_size, texture_size, 3)
    which is a square. To map a triangular texture to a square grid, each
    triangle is parametrized as follows (e.g. R = texture_size = 3):

    The triangle texture is first divided into RxR = 9 subtriangles which each
    map to one grid cell. The numbers in the grid cells and triangles show the mapping.

    .. code-block:: python

        Triangular Texture Space:

              1
                |\
                |6 \
                |____\
                |\  7 |\
                |3 \  |4 \
                |____\|____\
                |\ 8  |\  5 |\
                |0 \  |1 \  |2 \
                |____\|____\|____\
               0                   1

        Square per face texture map:

               R ____________________
                |      |      |      |
                |  6   |  7   |  8   |
                |______|______|______|
                |      |      |      |
                |  3   |  4   |  5   |
                |______|______|______|
                |      |      |      |
                |  0   |  1   |  2   |
                |______|______|______|
               0                      R


    The barycentric coordinates of each grid cell are calculated using the
    xy coordinates:

    .. code-block:: python

            The cartesian coordinates are:

            Grid 1:

               R ____________________
                |      |      |      |
                |  20  |  21  |  22  |
                |______|______|______|
                |      |      |      |
                |  10  |  11  |  12  |
                |______|______|______|
                |      |      |      |
                |  00  |  01  |  02  |
                |______|______|______|
               0                      R

            where 02 means y = 0, x = 2

        Now consider this subset of the triangle which corresponds to
        grid cells 0 and 8:

        .. code-block:: python

            1/R  ________
                |\    8  |
                |  \     |
                | 0   \  |
                |_______\|
               0          1/R

        The centroids of the triangles are:
            0: (1/3, 1/3) * 1/R
            8: (2/3, 2/3) * 1/R

    For each grid cell we can now calculate the centroid `(c_y, c_x)`
    of the corresponding texture triangle:
        - if `(x + y) < R`, then offset the centroid of
            triangle 0 by `(y, x) * (1/R)`
        - if `(x + y) >= R`, then offset the centroid of
            triangle 8 by `((R-1-y), (R-1-x)) * (1/R)`.

    This is equivalent to updating the portion of Grid 1
    above the diagonal, replacing `(y, x)` with `((R-1-y), (R-1-x))`:

    .. code-block:: python

              R _____________________
                |      |      |      |
                |  20  |  01  |  00  |
                |______|______|______|
                |      |      |      |
                |  10  |  11  |  10  |
                |______|______|______|
                |      |      |      |
                |  00  |  01  |  02  |
                |______|______|______|
               0                      R

    The barycentric coordinates (w0, w1, w2) are then given by:

    .. code-block:: python

        w0 = c_x
        w1 = c_y
        w2 = 1 - w0 - w1

    Args:
        image: FloatTensor of shape (H, W, 3)
        faces_verts_uvs: uv coordinates for each vertex in each face  (F, 3, 2)
        texture_size: int

    Returns:
        atlas: a FloatTensor of shape (F, texture_size, texture_size, 3) giving a
            per face texture map.

    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    R = texture_size
    device = faces_verts_uvs.device
    rng = torch.arange(R, device=device)

    # Meshgrid returns (row, column) i.e (Y, X)
    # Change order to (X, Y) to make the grid.
    Y, X = torch.meshgrid(rng, rng)
    # pyre-fixme[28]: Unexpected keyword argument `axis`.
    grid = torch.stack([X, Y], axis=-1)  # (R, R, 2)

    # Grid cells below the diagonal: x + y < R.
    below_diag = grid.sum(-1) < R

    # map a [0, R] grid -> to a [0, 1] barycentric coordinates of
    # the texture triangle centroids.
    bary = torch.zeros((R, R, 3), device=device)  # (R, R, 3)
    slc = torch.arange(2, device=device)[:, None]
    # w0, w1
    bary[below_diag, slc] = ((grid[below_diag] + 1.0 / 3.0) / R).T
    # w0, w1 for above diagonal grid cells.
    # pyre-fixme[16]: `float` has no attribute `T`.
    bary[~below_diag, slc] = (((R - 1.0 - grid[~below_diag]) + 2.0 / 3.0) / R).T
    # w2 = 1. - w0 - w1
    bary[..., -1] = 1 - bary[..., :2].sum(dim=-1)

    # Calculate the uv position in the image for each pixel
    # in the per face texture map
    # (F, 1, 1, 3, 2) * (R, R, 3, 1) -> (F, R, R, 3, 2) -> (F, R, R, 2)
    uv_pos = (faces_verts_uvs[:, None, None] * bary[..., None]).sum(-2)

    # bi-linearly interpolate the textures from the images
    # using the uv coordinates given by uv_pos.
    textures = _bilinear_interpolation_vectorized(image, uv_pos)

    return textures
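The centroid-to-barycentric computation in isolation, for R = 3 on CPU (a sketch mirroring the code above, not part of PyTorch3D):

import torch

R = 3
rng = torch.arange(R)
Y, X = torch.meshgrid(rng, rng)
grid = torch.stack([X, Y], dim=-1).float()   # (R, R, 2), (x, y) order
below_diag = grid.sum(-1) < R

bary = torch.zeros((R, R, 3))
bary[below_diag, :2] = (grid[below_diag] + 1.0 / 3.0) / R
bary[~below_diag, :2] = ((R - 1.0 - grid[~below_diag]) + 2.0 / 3.0) / R
bary[..., 2] = 1 - bary[..., :2].sum(dim=-1)
print(bary.sum(-1))  # all ones: valid barycentric weights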
Example n. 32
    def forward(self, x):
        x = self.model(x)
        return x


model = NETWORK()

# model.load_state_dict(torch.load("E:/python/Jupyter/TrainedModels_Saving/sinhtcosx_for_spare_use.pth"))
model.load_state_dict(torch.load("./nn_proportional_delay_PDE.pth"))

#plot 3-D figure
N_t, N_x = 500, 100
t = torch.linspace(0, 1.0, N_t)
x = torch.linspace(0, 2 * pi, N_x)
t, x = torch.meshgrid(t, x)
T_X = torch.stack([t.reshape(-1), x.reshape(-1)], dim=1)  # (N_t * N_x, 2), row-major like the original loop
u = model(T_X).detach().numpy().reshape(N_t, N_x)

t = t.numpy()
x = x.numpy()
u_acc = np.sinh(t) * np.cos(x)

fig = plt.figure()
fig.set_size_inches(10, 4)
ax0 = fig.add_subplot(1, 2, 1, projection='3d')
ax0.plot_surface(t, x, u, rstride=1, cstride=1, cmap='rainbow')
# ax0.grid(b = False)
Example n. 33
    def forward(self, box_cls_all, box_reg_all, centerness_all, boxes_all):
        device = box_cls_all.device
        boxes_per_image = [len(box) for box in boxes_all]
        cls = box_cls_all.split(boxes_per_image, dim=0)
        reg = box_reg_all.split(boxes_per_image, dim=0)
        center = centerness_all.split(boxes_per_image, dim=0)

        results = []
        for box_cls, box_regression, centerness, boxes in zip(cls, reg, center, boxes_all):
            N, C, H, W = box_cls.shape
            # put in the same format as locations
            box_cls = box_cls.permute(0, 2, 3, 1).reshape(N, -1, self.num_classes).sigmoid()
            box_regression = box_regression.permute(0, 2, 3, 1).reshape(N, -1, 4)
            centerness = centerness.permute(0, 2, 3, 1).reshape(N, -1).sigmoid()

            # multiply the classification scores with centerness scores
            box_cls = box_cls * centerness[:, :, None]
            _boxes = boxes.bbox
            size = boxes.size
            boxes_scores = boxes.get_field("scores")
            results_per_image = [boxes]
            for i in range(N):
                box = _boxes[i]
                boxes_score = boxes_scores[i]
                per_box_cls = box_cls[i]
                per_box_cls_max, per_box_cls_inds = per_box_cls.max(dim=0)

                per_class = torch.arange(2, 2 + self.num_classes, dtype=torch.long, device=device)  # torch.range is deprecated; arange's end is exclusive

                per_box_regression = box_regression[i]
                per_box_regression = per_box_regression[per_box_cls_inds]

                x_step = 1.0
                y_step = 1.0
                shifts_x = torch.arange(
                    0, self.m, step=x_step,
                    dtype=torch.float32, device=device
                ) + x_step / 2
                shifts_y = torch.arange(
                    0, self.m, step=y_step,
                    dtype=torch.float32, device=device
                ) + y_step / 2
                shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
                shift_x = shift_x.reshape(-1)
                shift_y = shift_y.reshape(-1)
                locations = torch.stack((shift_x, shift_y), dim=1)
                per_locations = locations[per_box_cls_inds]

                _x1 = per_locations[:, 0] - per_box_regression[:, 0]
                _y1 = per_locations[:, 1] - per_box_regression[:, 1]
                _x2 = per_locations[:, 0] + per_box_regression[:, 2]
                _y2 = per_locations[:, 1] + per_box_regression[:, 3]

                _x1 = _x1 / self.m * (box[2] - box[0]) + box[0]
                _y1 = _y1 / self.m * (box[3] - box[1]) + box[1]
                _x2 = _x2 / self.m * (box[2] - box[0]) + box[0]
                _y2 = _y2 / self.m * (box[3] - box[1]) + box[1]

                detections = torch.stack([_x1, _y1, _x2, _y2], dim=-1)

                boxlist = BoxList(detections, size, mode="xyxy")
                boxlist.add_field("labels", per_class)
                boxlist.add_field("scores", torch.sqrt(torch.sqrt(per_box_cls_max) * boxes_score))
                boxlist = boxlist.clip_to_image(remove_empty=False)
                boxlist = remove_small_boxes(boxlist, 0)
                results_per_image.append(boxlist)

            results_per_image = cat_boxlist(results_per_image)
            results.append(results_per_image)

        return results
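The location-grid decoding above, reduced to its essentials: an (l, t, r, b) regression at each cell of an m x m grid becomes an (x1, y1, x2, y2) box:

import torch

m = 4
shifts = torch.arange(0, m, dtype=torch.float32) + 0.5
shift_y, shift_x = torch.meshgrid(shifts, shifts)
locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)  # (m*m, 2)

reg = torch.rand(m * m, 4)               # per-cell (l, t, r, b) distances
x1y1 = locations - reg[:, :2]            # subtract left/top offsets
x2y2 = locations + reg[:, 2:]            # add right/bottom offsets
boxes = torch.cat([x1y1, x2y2], dim=1)   # (m*m, 4) in xyxy format
print(boxes.shape)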
Example n. 34
def optimize(ux0, uy0, im1, im2, maxIter, lambda_r, to1, theta, device):
    eps = 1e-7
    Ix, Iy = centralFiniteDifference(im1)
    It = im1 - im2

    a11 = Ix * Ix
    a12 = Ix * Iy
    a22 = Iy * Iy

    t1 = Ix * (It - Ix * ux0 - Iy * uy0)
    t2 = Iy * (It - Ix * ux0 - Iy * uy0)

    h, w = im1.size()

    vx = torch.zeros((h, w)).to(device).double()
    vy = torch.zeros((h, w)).to(device).double()
    bx = torch.zeros((h, w)).to(device).double()
    by = torch.zeros((h, w)).to(device).double()
    ux = torch.zeros((h, w)).to(device).double()
    uy = torch.zeros((h, w)).to(device).double()

    # An earlier variant of the frequency-domain symbol, kept for reference:
    # X, Y = torch.meshgrid([torch.arange(0, h), torch.arange(0, w)])
    # G = 2 * (torch.cos(PI * X / w + PI * Y / h) - 2)
    # G = G.to(device).double()

    X, Y = torch.meshgrid(torch.linspace(0, h - 1, h),
                          torch.linspace(0, w - 1, w))
    X, Y = X.to(device), Y.to(device)  # honor the device argument instead of hard-coding CUDA
    G = torch.cos(math.pi * X / h) + torch.cos(math.pi * Y / w) - 2
    # G = G.unsqueeze(0).repeat(N, 1, 1, 1)

    for i in range(maxIter):
        tempx = ux
        tempy = uy

        h1 = theta * (vx - bx) - t1
        h2 = theta * (vy - by) - t2

        ux = ((a22 + theta) * h1 - a12 * h2) / ((a11 + theta) *
                                                (a22 + theta) - a12 * a12)
        uy = ((a11 + theta) * h2 - a12 * h1) / ((a11 + theta) *
                                                (a22 + theta) - a12 * a12)

        # vx = (idct2(dct2(theta * (ux + bx)) / (theta + lambda_r * G * G)))
        # vy = (idct2(dct2(theta * (uy + by)) / (theta + lambda_r * G * G)))

        vx = (tdct.idct_2d(
            tdct.dct_2d(theta * (ux + bx)) / (theta + lambda_r * G * G)))
        vy = (tdct.idct_2d(
            tdct.dct_2d(theta * (uy + by)) / (theta + lambda_r * G * G)))

        bx = bx + ux - vx
        by = by + uy - vy

        # t1 = Ix * (It - Ix * ux - Iy * uy)
        # t2 = Iy * (It - Ix * ux - Iy * uy)

        stopx = torch.sum(
            torch.abs(ux - tempx)) / (torch.sum(torch.abs(tempx)) + eps)
        stopy = torch.sum(
            torch.abs(uy - tempy)) / (torch.sum(torch.abs(tempy)) + eps)
        # print(i, stopx, stopy)
        if stopx < to1 and stopy < to1:
            print('iterate {} times, stop due to converge to tolerance'.format(
                i))
            break

    if i == maxIter - 1:
        print('iterate {} times, stop due to reach max iteration'.format(i))
    return ux, uy
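The v-update above is a screened-Poisson-style solve in the DCT domain; a standalone sketch (assumes the torch-dct package, imported as tdct as in the code above):

import math
import torch
import torch_dct as tdct

h, w = 64, 64
theta, lambda_r = 0.1, 1.0
X, Y = torch.meshgrid(torch.linspace(0, h - 1, h), torch.linspace(0, w - 1, w))
G = torch.cos(math.pi * X / h) + torch.cos(math.pi * Y / w) - 2  # Laplacian symbol

rhs = torch.randn(h, w)  # stand-in for theta * (ux + bx)
v = tdct.idct_2d(tdct.dct_2d(theta * rhs) / (theta + lambda_r * G * G))
print(v.shape)  # torch.Size([64, 64])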