def findLR(self, model, optimizer, writer,
           start_lr=1e-7, end_lr=10, num_iters=50):
    """Run a learning-rate range test and log the resulting loss curve.

    The learning rate is swept geometrically from ``start_lr`` to ``end_lr``
    in ``num_iters`` steps. One real optimisation step is taken per learning
    rate on a batch from ``self.data_loaders[0]``, and the recorded losses
    are plotted against the learning rate (log x-axis) and handed to
    ``writer.add_figure`` under the tag ``'findLR'``.

    Args:
        model: network called as ``model(input_images)`` and returning
            ``(XYZ, maskLogit)``. Its weights are updated by the sweep.
        optimizer: optimizer whose ``param_groups[*]['lr']`` is overwritten
            at every step.
        writer: object exposing ``add_figure(tag, figure)``.
        start_lr: first learning rate of the sweep.
        end_lr: last learning rate of the sweep.
        num_iters: number of learning rates (and optimisation steps).
    """
    model.train()

    losses = []
    lrs = np.logspace(np.log10(start_lr), np.log10(end_lr), num_iters)

    # ------ define ground truth------
    # The pixel grid depends only on the config, so it is built once instead
    # of once per step.
    XGT, YGT = torch.meshgrid([
        torch.arange(self.cfg.outH),  # [H,W]
        torch.arange(self.cfg.outW)])  # [H,W]
    XGT, YGT = XGT.float(), YGT.float()
    XYGT = torch.cat([
        XGT.repeat([self.cfg.outViewN, 1, 1]),
        YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0)  # [2V,H,W]
    XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device)  # [1,2V,H,W]

    # Keep ONE iterator alive for the whole sweep. Calling
    # next(iter(loader)) every step rebuilds the loader iterator each time
    # and, for an unshuffled loader, always returns the same first batch.
    data_loader = self.data_loaders[0]
    batches = iter(data_loader)

    for lr in lrs:
        # Update LR
        for group in optimizer.param_groups:
            group['lr'] = lr

        try:
            batch = next(batches)
        except StopIteration:
            # Sweep is longer than one pass over the data: start over.
            batches = iter(data_loader)
            batch = next(batches)

        input_images, depthGT, maskGT = utils.unpack_batch_fixed(
            batch, self.cfg.device)

        with torch.set_grad_enabled(True):
            optimizer.zero_grad()

            XYZ, maskLogit = model(input_images)
            XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
            depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN * 3, :, :]
            # Use the comparison result directly: casting to uint8 yields a
            # mask dtype that recent PyTorch rejects in masked_select.
            mask = maskLogit > 0

            # ------ Compute loss ------
            loss_XYZ = self.l1(XY, XYGT)
            loss_XYZ += self.l1(depth.masked_select(mask),
                                depthGT.masked_select(mask))
            loss_mask = self.sigmoid_bce(maskLogit, maskGT)
            loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

            # Update weights
            loss.backward()

            # True Weight decay: w <- w - trueWD * lr * w, written as a
            # scaling so it does not rely on the removed
            # Tensor.add(scalar, tensor) overload.
            if self.cfg.trueWD is not None:
                for group in optimizer.param_groups:
                    decay = 1 - self.cfg.trueWD * group['lr']
                    for param in group['params']:
                        param.data.mul_(decay)

            optimizer.step()

        losses.append(loss.item())

    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    ax.set_xlabel('learning rate')
    ax.set_ylabel('loss')
    ax.set_xscale('log')
    writer.add_figure('findLR', fig)
def _val_on_epoch(self, model):
    """Evaluate ``model`` on the validation loader and return mean losses.

    Iterates once over ``self.data_loaders[1]`` with gradients disabled and
    accumulates the XY/depth L1 loss, the mask BCE loss and their weighted
    sum (``loss_mask + cfg.lambdaDepth * loss_XYZ``). Each batch loss is
    weighted by the batch size and the totals are divided by
    ``len(data_loader.dataset)``.

    Args:
        model: network called as ``model(input_images)`` and returning
            ``(XYZ, maskLogit)``; it is switched to eval mode.

    Returns:
        dict with keys ``"epoch_loss_XYZ"``, ``"epoch_loss_mask"`` and
        ``"epoch_loss"``.
    """
    model.eval()

    data_loader = self.data_loaders[1]
    running_loss_XYZ = 0.0
    running_loss_mask = 0.0
    running_loss = 0.0

    # ------ define ground truth------
    # Depends only on the config, so build it once instead of per batch.
    XGT, YGT = torch.meshgrid([
        torch.arange(self.cfg.outH),  # [H,W]
        torch.arange(self.cfg.outW)])  # [H,W]
    XGT, YGT = XGT.float(), YGT.float()
    XYGT = torch.cat([
        XGT.repeat([self.cfg.outViewN, 1, 1]),
        YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0)  # [2V,H,W]
    XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device)  # [1,2V,H,W]

    for batch in data_loader:
        input_images, depthGT, maskGT = utils.unpack_batch_fixed(
            batch, self.cfg.device)

        with torch.set_grad_enabled(False):
            XYZ, maskLogit = model(input_images)
            XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
            depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN * 3, :, :]
            # Use the comparison result directly: casting to uint8 yields a
            # mask dtype that recent PyTorch rejects in masked_select.
            mask = maskLogit > 0

            # ------ Compute loss ------
            loss_XYZ = self.l1(XY, XYGT)
            loss_XYZ += self.l1(depth.masked_select(mask),
                                depthGT.masked_select(mask))
            loss_mask = self.sigmoid_bce(maskLogit, maskGT)
            loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

        batch_size = input_images.size(0)
        running_loss_XYZ += loss_XYZ.item() * batch_size
        running_loss_mask += loss_mask.item() * batch_size
        running_loss += loss.item() * batch_size

    num_samples = len(data_loader.dataset)
    epoch_loss_XYZ = running_loss_XYZ / num_samples
    epoch_loss_mask = running_loss_mask / num_samples
    epoch_loss = running_loss / num_samples

    print(f"\tVal loss: {epoch_loss}")

    return {"epoch_loss_XYZ": epoch_loss_XYZ,
            "epoch_loss_mask": epoch_loss_mask,
            "epoch_loss": epoch_loss, }
def grid_anchors(self, grid_sizes):
    """Tile every level's base anchors over that level's feature grid.

    For each ``(grid size, stride, base anchors)`` triple taken from
    ``grid_sizes``, ``self.strides`` and ``self.cell_anchors``, the base
    anchors are shifted by ``stride`` steps along x and y. Positions are
    enumerated row by row (x varies fastest).

    Returns:
        list with one ``(rows * cols * num_base_anchors, 4)`` tensor per
        level, laid out as ``[x, y, x, y]`` offsets added to the base boxes.
    """
    tiled_per_level = []
    levels = zip(grid_sizes, self.strides, self.cell_anchors)
    for (rows, cols), step, cell in levels:
        dev = cell.device
        xs = torch.arange(
            0, cols * step, step=step, dtype=torch.float32, device=dev)
        ys = torch.arange(
            0, rows * step, step=step, dtype=torch.float32, device=dev)

        grid_y, grid_x = torch.meshgrid(ys, xs)
        flat_x = grid_x.reshape(-1)
        flat_y = grid_y.reshape(-1)
        offsets = torch.stack((flat_x, flat_y, flat_x, flat_y), dim=1)

        # (positions, 1, 4) + (1, anchors, 4) -> every anchor at every cell
        tiled = offsets.view(-1, 1, 4) + cell.view(1, -1, 4)
        tiled_per_level.append(tiled.reshape(-1, 4))

    return tiled_per_level
def get_grid_coords_3d(self, z, y, x, coord_dim=-1):
    """Build a dense 3-D coordinate grid from three 1-D axes.

    The grid is indexed ``[z, y, x]`` while the stacked coordinate vector is
    ordered ``(x, y, z)`` along ``coord_dim``.
    """
    zyx_grids = torch.meshgrid(z, y, x)
    # meshgrid yields (z, y, x); reverse to get the (x, y, z) channel order.
    xyz_grids = tuple(reversed(zyx_grids))
    return torch.stack(xyz_grids, dim=coord_dim)
def _make_grid(nx=20, ny=20): yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
def __call__(self, pred_hm, pred_wh, pred_centerness, heatmap, box_target,
             centerness, wh_weight, hm_weight):
    """Compute the heatmap, box (GIoU), centerness and merge losses.

    Args:
        pred_hm: tensor, (batch, 80, h, w).
        pred_wh: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        pred_centerness: tensor or None, (batch, 1, h, w).
        heatmap: tensor, (batch, 80, h, w).
        box_target: tensor, (batch, 4, h, w) or (batch, 80 * 4, h, w).
        centerness: tensor or None, (batch, 1, h, w).
        wh_weight: tensor or None, (batch, 80, h, w).
        hm_weight: per-element heatmap weight forwarded to ``ct_focal_loss``
            in the non-fovea branch; ignored when ``self.ct_version`` is set.

    Returns:
        tuple ``(hm_loss, wh_loss, centerness_loss, merge_loss)``. Losses
        that do not apply to the active configuration are zero tensors.

    Note:
        When ``self.fovea_hm`` is false, ``pred_hm`` is passed through
        ``sigmoid_()`` in place, so the caller's tensor is modified.
    """
    # Periodic feature-map summaries (every 100 local steps).
    if every_n_local_step(100):
        pred_hm_summary = torch.clamp(torch.sigmoid(pred_hm),
                                      min=1e-4, max=1 - 1e-4)
        gt_hm_summary = heatmap.clone()
        if self.fovea_hm:
            if not self.only_merge:
                pred_ctn_summary = torch.clamp(
                    torch.sigmoid(pred_centerness), min=1e-4, max=1 - 1e-4)
                add_feature_summary(
                    'centernet/centerness',
                    pred_ctn_summary.detach().cpu().numpy(), type='f')
                add_feature_summary(
                    'centernet/merge',
                    (pred_ctn_summary * pred_hm_summary).detach().cpu().numpy(),
                    type='max')
            add_feature_summary('centernet/gt_centerness',
                                centerness.detach().cpu().numpy(), type='f')
            add_feature_summary(
                'centernet/gt_merge',
                (centerness * gt_hm_summary).detach().cpu().numpy(),
                type='max')
        add_feature_summary('centernet/heatmap',
                            pred_hm_summary.detach().cpu().numpy())
        add_feature_summary('centernet/gt_heatmap',
                            gt_hm_summary.detach().cpu().numpy())

    H, W = pred_hm.shape[2:]
    if not self.fovea_hm:
        # In-place sigmoid: mutates the caller's pred_hm.
        pred_hm = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_weight = None if self.ct_version else hm_weight
        hm_loss = ct_focal_loss(pred_hm, heatmap,
                                hm_weight=hm_weight) * self.hm_weight
        centerness_loss = hm_loss.new_tensor([0.])
        merge_loss = hm_loss.new_tensor([0.])
    else:
        care_mask = (heatmap >= 0).float()
        avg_factor = torch.sum(heatmap > 0).float().item() + 1e-6
        if not self.only_merge:
            hm_loss = py_sigmoid_focal_loss(
                pred_hm, heatmap, care_mask,
                reduction='sum') / avg_factor * self.hm_weight
            pred_centerness = torch.clamp(torch.sigmoid(pred_centerness),
                                          min=1e-4, max=1 - 1e-4)
            centerness_loss = ct_focal_loss(
                pred_centerness, centerness, gamma=2.) * self.ct_weight
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm) * pred_centerness,
                            min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight
        else:
            hm_loss = pred_hm.new_tensor([0.])
            centerness_loss = pred_hm.new_tensor([0.])
            merge_loss = ct_focal_loss(
                torch.clamp(torch.sigmoid(pred_hm), min=1e-4, max=1 - 1e-4),
                heatmap * centerness,
                weight=(heatmap >= 0).float()) * self.merge_weight

    if not self.wh_agnostic:
        # Fold the per-class box channels into the batch dimension.
        pred_wh = pred_wh.view(pred_wh.size(0) * pred_hm.size(1), 4, H, W)
        box_target = box_target.view(
            box_target.size(0) * pred_hm.size(1), 4, H, W)
    mask = wh_weight.view(-1, H, W)
    avg_factor = mask.sum() + 1e-4

    # The location grid is cached on the instance. Rebuild it whenever the
    # feature-map size differs from the cached one; otherwise a change of
    # input resolution would reuse a grid of the wrong shape.
    if (self.base_loc is None or H != self.base_loc.shape[1]
            or W != self.base_loc.shape[2]):
        base_step = self.down_ratio
        shifts_x = torch.arange(0, (W - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shifts_y = torch.arange(0, (H - 1) * base_step + 1, base_step,
                                dtype=torch.float32, device=heatmap.device)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        self.base_loc = torch.stack((shift_x, shift_y), dim=0)  # (2, h, w)

    # (batch, h, w, 4)
    pred_boxes = torch.cat((self.base_loc - pred_wh[:, [0, 1]],
                            self.base_loc + pred_wh[:, [2, 3]]),
                           dim=1).permute(0, 2, 3, 1)
    # (batch, h, w, 4)
    boxes = box_target.permute(0, 2, 3, 1)
    wh_loss = giou_loss(pred_boxes, boxes, mask,
                        avg_factor=avg_factor) * self.giou_weight

    return hm_loss, wh_loss, centerness_loss, merge_loss
def test_meshgrid(self):
    """Check the ONNX export of a three-input ``torch.meshgrid`` call."""
    inputs = (
        torch.ones(3, requires_grad=True),
        torch.zeros(4, requires_grad=True),
        torch.ones(5, requires_grad=True),
    )
    self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), inputs)
def warp_image_torch(img, H, out_shape=None):
    """Warp a torch image tensor with a homography.

    Arguments:
        img: Tensor of shape (B,C,H,W) or (C,H,W)
        H: Tensor of shape (B,3,3) or (3,3), the homography
        out_shape: Tuple, the wanted shape of the out image

    Returns:
        A Tensor of shape (B) x (out_shape) or (B) x (img.shape), the
        warped image. NaN samples are replaced by zeros.

    Raises:
        ValueError: If img and H batch sizes are different
    """
    # Promote unbatched inputs to a batch of one.
    if img.dim() < 4:
        img = img[None]
    if H.dim() < 3:
        H = H[None]
    if out_shape is None:
        out_shape = img.shape[-2:]

    n_img, n_hom = img.shape[0], H.shape[0]
    if n_img != n_hom:
        raise ValueError(
            "batch size of images ({}) do not match the batch size of homographies ({})"
            .format(n_img, n_hom))
    batchsize = n_img
    out_h, out_w = out_shape[-2], out_shape[-1]

    # Output sampling positions in frame coordinates, range [-0.5, 0.5].
    rows, cols = torch.meshgrid([
        torch.linspace(-0.5, 0.5, steps=out_h),
        torch.linspace(-0.5, 0.5, steps=out_w),
    ])
    cols = cols.to(img.device).flatten()
    rows = rows.to(img.device).flatten()

    # Homogeneous coordinates, replicated per batch element: (B, 3, N).
    points = torch.stack([cols, rows, torch.ones_like(cols)])
    points = points.repeat([batchsize, 1, 1])

    # Map the points to model coordinates.
    warped = H @ points
    warped_xy, warped_w = torch.split(warped, 2, dim=1)

    # The homographies map into [-0.5, 0.5] (as in the GT datasets), hence
    # the factor 2 to reach grid_sample's [-1, 1] range.
    warped_xy = 2.0 * warped_xy / (warped_w + 1e-8)
    sample_x = warped_xy[:, 0]
    sample_y = warped_xy[:, 1]

    grid = torch.stack(
        [
            sample_x.view(batchsize, out_h, out_w),
            sample_y.view(batchsize, out_h, out_w),
        ],
        dim=-1,
    )

    warped_img = torch.nn.functional.grid_sample(
        img, grid, mode="bilinear", padding_mode="zeros")

    if hasnan(warped_img):
        print("nan value in warped image! set to zeros")
        warped_img[isnan(warped_img)] = 0

    return warped_img
def render(seed=None, xlim=[-1.0, 1.0], ylim=None, xres=1024, yres=None,
           units=16, depth=8, hidden_std=1.0, output_std=1.0, channels=3,
           radius=True, bias=True, z=None, device='cpu'):
    """Render an image by pushing pixel coordinates through a random MLP.

    Every pixel's ``(y, x)`` coordinate (optionally extended with its L2
    norm and a constant vector ``z``) is fed through ``depth`` tanh layers
    of ``units`` units with freshly sampled Gaussian weights, followed by a
    sigmoid output layer with ``channels`` outputs.

    The global torch RNG state (CPU and all CUDA devices) is saved before
    seeding and restored afterwards, also when rendering fails.

    Args:
        seed: RNG seed; a random 32-bit seed is drawn when ``None``.
        xlim: ``[min, max]`` of the x axis. The default list is never
            mutated here.
        ylim: ``[min, max]`` of the y axis; a copy of ``xlim`` when ``None``.
        xres: output width in pixels.
        yres: output height in pixels; derived from the aspect ratio of the
            limits when ``None``.
        units: units per hidden layer.
        depth: number of hidden layers.
        hidden_std: standard deviation of the hidden-layer weights.
        output_std: standard deviation of the output-layer weights.
        channels: number of output channels.
        radius: append each coordinate's L2 norm as an extra input.
        bias: prepend a constant-one bias column before every layer.
        z: optional extra input values appended to every pixel's input.
        device: torch device to compute on; must be in ``get_devices()``.

    Returns:
        ``numpy.ndarray`` of dtype uint8 and shape ``(yres, xres, channels)``.

    Raises:
        RuntimeError: if ``device`` is not an available device.
    """
    if device not in get_devices():
        raise RuntimeError('Device {} not in available devices: {}'.format(
            device, ', '.join(get_devices())))

    cpu_rng_state = torch.get_rng_state()
    cuda_rng_states = []
    if torch.cuda.is_available():
        cuda_rng_states = [torch.cuda.get_rng_state(idx)
                           for idx in range(torch.cuda.device_count())]

    try:
        if seed is None:
            seed = random.Random().randint(0, 2 ** 32 - 1)
        torch.cuda.manual_seed_all(seed)
        torch.manual_seed(seed)

        if ylim is None:
            ylim = copy.copy(xlim)
        if yres is None:
            yxscale = float(ylim[1] - ylim[0]) / (xlim[1] - xlim[0])
            yres = int(yxscale * xres)

        x = torch.linspace(xlim[0], xlim[1], xres, device=device)
        y = torch.linspace(ylim[0], ylim[1], yres, device=device)

        # Pass indexing='ij' only where torch.meshgrid accepts it.
        meshgrid_kwargs = {}
        if inspect.signature(torch.meshgrid).parameters.get('indexing'):
            meshgrid_kwargs['indexing'] = 'ij'
        grid = torch.meshgrid((y, x), **meshgrid_kwargs)

        inputs = torch.cat((grid[0].flatten().unsqueeze(1),
                            grid[1].flatten().unsqueeze(1)), -1)
        if radius:
            inputs = torch.cat(
                (inputs, torch.norm(inputs, 2, 1).unsqueeze(1)), -1)
        if z is not None:
            zrep = torch.tensor(z, dtype=inputs.dtype, device=device).repeat(
                (inputs.shape[0], 1))
            inputs = torch.cat((inputs, zrep), -1)

        activations = inputs
        # Distinct loop name so the ``units`` parameter is not shadowed.
        for layer_units in [units] * depth:
            if bias:
                bias_array = torch.ones((activations.shape[0], 1),
                                        device=device)
                activations = torch.cat((bias_array, activations), -1)
            hidden_layer_weights = torch.randn(
                (activations.shape[1], layer_units),
                device=device) * hidden_std
            activations = torch.tanh(
                torch.mm(activations, hidden_layer_weights))

        if bias:
            bias_array = torch.ones((activations.shape[0], 1), device=device)
            activations = torch.cat((bias_array, activations), -1)
        output_layer_weights = torch.randn(
            (activations.shape[1], channels), device=device) * output_std
        output = torch.sigmoid(torch.mm(activations, output_layer_weights))
        output = output.reshape((yres, xres, channels))
    finally:
        # Always hand the caller's RNG streams back untouched.
        torch.set_rng_state(cpu_rng_state)
        for idx, cuda_rng_state in enumerate(cuda_rng_states):
            torch.cuda.set_rng_state(cuda_rng_state, idx)

    return (output.cpu() * 255).round().type(torch.uint8).numpy()
def create_plots(robot, obstacles, dist_est, checker):
    """Create the workspace figure (and, for 2-DOF robots, C-space panels).

    For ``robot.dof == 2`` the figure gets one extra bottom row with one
    configuration-space panel per class: a colour map of ``dist_est``
    evaluated on a 400x400 grid over ``[-pi, pi]^2``, the checker's support
    points with non-zero gain for that class, the zero level set, and four
    empty line artists for configuration paths.

    Args:
        robot: object exposing ``dof`` and ``link_width``.
        obstacles: iterable of ``(kind, center, size[, category])`` entries
            where ``kind`` is ``'circle'`` or ``'rect'``; category defaults
            to 1 when absent.
        dist_est: callable mapping ``(N, 2)`` grid points to scores that
            reshape to ``(400, 400, num_class)``.
        checker: object exposing ``support_points``, ``gains`` and
            optionally ``num_class`` (default 1).

    Returns:
        ``(fig, ax, link_plot, joint_plot, eff_plot)`` when ``robot.dof > 2``;
        the same tuple plus ``cfg_path_plots`` when ``robot.dof == 2``.

    Raises:
        ValueError: if ``robot.dof`` is smaller than 2 (previously this fell
            through to an ``UnboundLocalError``).
    """
    if robot.dof < 2:
        raise ValueError(
            'create_plots needs a robot with at least 2 DOF, got {}'.format(
                robot.dof))

    # plt.get_cmap is the stable accessor; matplotlib.cm.get_cmap was
    # removed in matplotlib 3.9.
    cmaps = [plt.get_cmap('Reds'), plt.get_cmap('Blues')]

    if robot.dof > 2:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111)
    elif robot.dof == 2:
        # Show C-space at the same time
        num_class = getattr(checker, 'num_class', 1)
        fig = plt.figure(figsize=(3 * (num_class), 3 * (num_class + 1)))
        plt.rcParams.update({
            "text.usetex": True,
            "font.family": "sans-serif",
            "font.sans-serif": ["Helvetica"]
        })
        gs = fig.add_gridspec(num_class + 1, num_class)
        ax = fig.add_subplot(gs[:-1, :])
        cfg_path_plots = []

        size = [400, 400]
        yy, xx = torch.meshgrid(torch.linspace(-np.pi, np.pi, size[0]),
                                torch.linspace(-np.pi, np.pi, size[1]))
        grid_points = torch.stack([xx, yy], dim=2).reshape((-1, 2))
        # Match the checker's precision.
        grid_points = grid_points.double(
        ) if checker.support_points.dtype == torch.float64 else grid_points
        score_spline = dist_est(grid_points).reshape(size + [num_class])

        with sns.axes_style('ticks'):
            for cat in range(num_class):
                c_ax = fig.add_subplot(gs[-1, cat])
                score = score_spline[:, :, cat]
                # Symmetric colour range so that zero sits mid-colormap.
                c_ax.pcolormesh(xx, yy, score, cmap=cmaps[cat],
                                vmin=-torch.abs(score).max(),
                                vmax=torch.abs(score).max())
                c_support_points = checker.support_points[
                    checker.gains[:, cat] != 0]
                c_ax.scatter(c_support_points[:, 0], c_support_points[:, 1],
                             marker='.', c='black', s=1.5)
                c_ax.contour(xx, yy, score, levels=[0],
                             linewidths=1, alpha=0.4)

                # Placeholders for configuration paths drawn later.
                for _ in range(4):
                    cfg_path, = c_ax.plot([], [], '-o', c='olivedrab',
                                          markersize=3)
                    cfg_path_plots.append(cfg_path)

                c_ax.set_aspect('equal', adjustable='box')
                c_ax.set_xlim(-np.pi, np.pi)
                c_ax.set_ylim(-np.pi, np.pi)
                # Raw strings: '\p' is not a valid escape sequence.
                c_ax.set_xticks([-np.pi, 0, np.pi])
                c_ax.set_xticklabels([r'$-\pi$', r'$0$', r'$\pi$'])
                c_ax.set_yticks([-np.pi, 0, np.pi])
                c_ax.set_yticklabels([r'$-\pi$', r'$0$', r'$\pi$'])

    # Plot obstacles
    ax.set_xlim(-8, 8)
    ax.set_ylim(-8, 8)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xticks([-4, 0, 4])
    ax.set_yticks([-4, 0, 4])
    for obs in obstacles:
        cat = obs[3] if len(obs) >= 4 else 1
        if obs[0] == 'circle':
            ax.add_patch(
                Circle(obs[1], obs[2],
                       path_effects=[path_effects.withSimplePatchShadow()],
                       color=cmaps[cat](0.5)))
        elif obs[0] == 'rect':
            # Rectangle is anchored at its lower-left corner; obs[1] is the
            # centre and obs[2] the (width, height).
            ax.add_patch(
                Rectangle((obs[1][0] - float(obs[2][0]) / 2,
                           obs[1][1] - float(obs[2][1]) / 2),
                          obs[2][0], obs[2][1],
                          path_effects=[path_effects.withSimplePatchShadow()],
                          color=cmaps[cat](0.5)))

    # Placeholder of the robot plot. The line width is the link width
    # converted from data units to points.
    trans = ax.transData.transform
    lw = ((trans((1, robot.link_width)) - trans(
        (0, 0))) * 72 / ax.figure.dpi)[1]
    link_plot, = ax.plot(
        [], [], color='silver', alpha=0.1, lw=lw, solid_capstyle='round',
        path_effects=[path_effects.SimpleLineShadow(),
                      path_effects.Normal()])
    joint_plot, = ax.plot([], [], 'o', color='tab:red', markersize=lw)
    eff_plot, = ax.plot([], [], 'o', color='black', markersize=lw)

    if robot.dof > 2:
        return fig, ax, link_plot, joint_plot, eff_plot
    elif robot.dof == 2:
        return fig, ax, link_plot, joint_plot, eff_plot, cfg_path_plots
def fcos_losses(
    self,
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    controllers_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
    matched_idxes,
    im_idxes,
    locations,
):
    """Compute the classification, box, centerness and mask losses.

    Locations whose label differs from ``num_classes`` (the column count of
    ``logits_pred``) are treated as positives. The classification loss uses
    all locations; every other input is filtered down to the positives.
    Mask predictions are produced per image by a three-layer grouped 1x1
    convolution whose weights come from ``controllers_pred`` and whose
    input is ``self.masks`` concatenated with relative coordinates.

    Returns:
        tuple ``(losses, {})`` where ``losses`` has the keys
        ``"loss_fcos_cls"``, ``"loss_fcos_loc"``, ``"loss_fcos_ctr"`` and
        ``"loss_mask"``.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Positives are locations not labelled with the value ``num_classes``.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    # Average positive count per process, floored at 1 to avoid dividing by 0.
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = (sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg)

    # Keep only the positive locations from here on.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    controllers_pred = controllers_pred[pos_inds]
    matched_idxes = matched_idxes[pos_inds]
    im_idxes = im_idxes[pos_inds]
    locations = locations[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    # iou_loss receives the centerness targets as its third argument.
    reg_loss = iou_loss(reg_pred, reg_targets,
                        ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    # for CondInst
    # NOTE(review): this value is overwritten below before it is read.
    batch_ins = pos_inds.shape[0]
    N, C, h, w = self.masks.shape
    # NOTE(review): locations are clamped to the mask-feature size and used
    # directly as indices, so they are presumably already expressed in
    # mask-feature pixels rather than image pixels; confirm with the caller.
    center_x = torch.clamp(locations[:, 0], min=0, max=w - 1).long()
    center_y = torch.clamp(locations[:, 1], min=0, max=h - 1).long()

    # Normalised [-1, 1] coordinate grid, shape (1, 2, h, w) after the cat.
    x_range = torch.linspace(-1, 1, w, device=self.masks.device)
    y_range = torch.linspace(-1, 1, h, device=self.masks.device)
    y, x = torch.meshgrid(y_range, x_range)
    x = x.unsqueeze(0).unsqueeze(0)
    y = y.unsqueeze(0).unsqueeze(0)
    grid = torch.cat([x, y], 1)

    # Per-instance centre in the same normalised units, shape (P, 2, 1, 1).
    offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    offset_xy = torch.cat([offset_x, offset_y], 1)
    # Coordinates relative to each instance's centre: (P, 2, h, w).
    coords_feat = grid - offset_xy

    masks_feat = self.masks
    r_h = int(h * self.strides[0])
    r_w = int(w * self.strides[0])
    targets_masks = [
        target_im.gt_masks.tensor for target_im in self.gt_instances
    ]
    masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)

    mask_loss = masks_feat[0].new_tensor(0.0)
    batch_ins = im_idxes.shape[0]

    # for each image
    for i in range(N):
        inds = (im_idxes == i).nonzero().flatten()
        ins_num = inds.shape[0]
        if ins_num > 0:
            controllers = controllers_pred[inds]
            coord_feat = coords_feat[inds]
            # Replicate this image's mask features once per instance.
            mask_feat = masks_feat[None, i]
            mask_feat = torch.cat([mask_feat] * ins_num, dim=0)
            # Stack all instances along channels so one grouped conv
            # evaluates every instance's head at once.
            comb_feat = torch.cat((mask_feat, coord_feat),
                                  dim=1).view(1, -1, h, w)

            # Controller layout: 8x10 weights, 8 biases, 8x8 weights,
            # 8 biases, 1x8 weights, 1 bias (169 values per instance).
            # NOTE(review): the 10 input channels imply C == 8 mask
            # channels plus the 2 coordinate channels; confirm.
            weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(
                controllers, [80, 8, 64, 8, 8, 1], dim=1)
            bias1, bias2, bias3 = bias1.flatten(), bias2.flatten(
            ), bias3.flatten()
            weight1 = weight1.reshape(-1, 8, 10).reshape(
                -1, 10).unsqueeze(-1).unsqueeze(-1)
            weight2 = weight2.reshape(-1, 8, 8).reshape(
                -1, 8).unsqueeze(-1).unsqueeze(-1)
            weight3 = weight3.unsqueeze(-1).unsqueeze(-1)

            # groups=ins_num keeps each instance's channels separate.
            conv1 = F.conv2d(comb_feat, weight1, bias1,
                             groups=ins_num).relu()
            conv2 = F.conv2d(conv1, weight2, bias2, groups=ins_num).relu()
            masks_per_image = F.conv2d(conv2, weight3, bias3,
                                       groups=ins_num)
            masks_per_image = aligned_bilinear(
                masks_per_image, self.strides[0])[0].sigmoid()

            # Accumulate the dice loss against each instance's matched GT.
            for j in range(ins_num):
                ind = inds[j]
                mask_gt = masks_t[i][matched_idxes[ind]].float()
                mask_pred = masks_per_image[j]
                mask_loss += self.dice_loss(mask_pred, mask_gt)

    # Average over positive instances; stays 0 when there are none.
    if batch_ins > 0:
        mask_loss = mask_loss / batch_ins

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_mask": mask_loss,
    }
    return losses, {}
def combine(self, local_parameters, param_idx, user_idx):
    """Average client parameters into ``self.global_parameters`` in place.

    For every global tensor, each client's (possibly smaller) local tensor
    is added into a float32 accumulator at the positions given by
    ``param_idx[m][k]`` while a matching counter records how many clients
    touched each entry. Entries with a non-zero count are replaced by the
    mean; untouched entries keep their previous global value.

    Label-dependent layers (the output layer for ``'conv'``, names
    containing ``'linear'`` for resnets, the embedding and decoder
    ``linear2`` for the transformer) are additionally restricted to the rows
    in ``self.label_split[user_idx[m]]``.

    Side effects:
        ``param_idx[m][k]`` is overwritten for those label-dependent layers
        (narrowed to the client's label rows).

    Args:
        local_parameters: sequence of per-client mappings name -> tensor.
        param_idx: per-client mappings name -> index (or tuple of per-dim
            indices for tensors with more than one dimension).
        user_idx: per-client keys into ``self.label_split``.

    Raises:
        ValueError: if ``cfg['model_name']`` matches none of the supported
            model families.
    """
    count = OrderedDict()
    if cfg['model_name'] == 'conv':
        # The last weight/bias in iteration order is taken as the output layer.
        output_weight_name = [
            k for k in self.global_parameters.keys() if 'weight' in k
        ][-1]
        output_bias_name = [
            k for k in self.global_parameters.keys() if 'bias' in k
        ][-1]
        for k, v in self.global_parameters.items():
            parameter_type = k.split('.')[-1]
            count[k] = v.new_zeros(v.size(), dtype=torch.float32)
            tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
            for m in range(len(local_parameters)):
                if 'weight' in parameter_type or 'bias' in parameter_type:
                    if parameter_type == 'weight':
                        if v.dim() > 1:
                            if k == output_weight_name:
                                # Keep only this client's label rows.
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = list(param_idx[m][k])
                                param_idx[m][k][0] = param_idx[m][k][0][
                                    label_split]
                                tmp_v[torch.meshgrid(
                                    param_idx[m][k]
                                )] += local_parameters[m][k][label_split]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                            else:
                                # meshgrid expands the per-dim indices into
                                # a full sub-block index.
                                tmp_v[torch.meshgrid(
                                    param_idx[m]
                                    [k])] += local_parameters[m][k]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                    else:
                        if k == output_bias_name:
                            label_split = self.label_split[user_idx[m]]
                            param_idx[m][k] = param_idx[m][k][label_split]
                            tmp_v[param_idx[m][k]] += local_parameters[m][
                                k][label_split]
                            count[k][param_idx[m][k]] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                else:
                    # Neither weight nor bias: added over the whole tensor.
                    tmp_v += local_parameters[m][k]
                    count[k] += 1
            # Mean over contributing clients; untouched entries stay as-is.
            tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                count[k][count[k] > 0])
            v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
    elif 'resnet' in cfg['model_name']:
        for k, v in self.global_parameters.items():
            parameter_type = k.split('.')[-1]
            count[k] = v.new_zeros(v.size(), dtype=torch.float32)
            tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
            for m in range(len(local_parameters)):
                if 'weight' in parameter_type or 'bias' in parameter_type:
                    if parameter_type == 'weight':
                        if v.dim() > 1:
                            if 'linear' in k:
                                # Keep only this client's label rows.
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = list(param_idx[m][k])
                                param_idx[m][k][0] = param_idx[m][k][0][
                                    label_split]
                                tmp_v[torch.meshgrid(
                                    param_idx[m][k]
                                )] += local_parameters[m][k][label_split]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                            else:
                                tmp_v[torch.meshgrid(
                                    param_idx[m]
                                    [k])] += local_parameters[m][k]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                    else:
                        if 'linear' in k:
                            label_split = self.label_split[user_idx[m]]
                            param_idx[m][k] = param_idx[m][k][label_split]
                            tmp_v[param_idx[m][k]] += local_parameters[m][
                                k][label_split]
                            count[k][param_idx[m][k]] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                else:
                    # Neither weight nor bias: added over the whole tensor.
                    tmp_v += local_parameters[m][k]
                    count[k] += 1
            tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                count[k][count[k] > 0])
            v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
    elif cfg['model_name'] == 'transformer':
        for k, v in self.global_parameters.items():
            parameter_type = k.split('.')[-1]
            count[k] = v.new_zeros(v.size(), dtype=torch.float32)
            tmp_v = v.new_zeros(v.size(), dtype=torch.float32)
            for m in range(len(local_parameters)):
                if 'weight' in parameter_type or 'bias' in parameter_type:
                    # Substring match here, unlike the equality test used in
                    # the other two branches.
                    if 'weight' in parameter_type:
                        if v.dim() > 1:
                            if k.split('.')[-2] == 'embedding':
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = list(param_idx[m][k])
                                param_idx[m][k][0] = param_idx[m][k][0][
                                    label_split]
                                tmp_v[torch.meshgrid(
                                    param_idx[m][k]
                                )] += local_parameters[m][k][label_split]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                            elif 'decoder' in k and 'linear2' in k:
                                label_split = self.label_split[user_idx[m]]
                                param_idx[m][k] = list(param_idx[m][k])
                                param_idx[m][k][0] = param_idx[m][k][0][
                                    label_split]
                                tmp_v[torch.meshgrid(
                                    param_idx[m][k]
                                )] += local_parameters[m][k][label_split]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                            else:
                                tmp_v[torch.meshgrid(
                                    param_idx[m]
                                    [k])] += local_parameters[m][k]
                                count[k][torch.meshgrid(
                                    param_idx[m][k])] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                    else:
                        if 'decoder' in k and 'linear2' in k:
                            label_split = self.label_split[user_idx[m]]
                            param_idx[m][k] = param_idx[m][k][label_split]
                            tmp_v[param_idx[m][k]] += local_parameters[m][
                                k][label_split]
                            count[k][param_idx[m][k]] += 1
                        else:
                            tmp_v[param_idx[m]
                                  [k]] += local_parameters[m][k]
                            count[k][param_idx[m][k]] += 1
                else:
                    # Neither weight nor bias: added over the whole tensor.
                    tmp_v += local_parameters[m][k]
                    count[k] += 1
            tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(
                count[k][count[k] > 0])
            v[count[k] > 0] = tmp_v[count[k] > 0].to(v.dtype)
    else:
        raise ValueError('Not valid model name')
    return
def meshgrid(a: torch.Tensor,
             b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the two ``'ij'``-indexed coordinate grids of ``a`` and ``b``.

    Thin typed wrapper around ``torch.meshgrid`` that always returns exactly
    two tensors as a tuple.
    """
    grids = torch.meshgrid(a, b, indexing="ij")
    first = grids[0]
    second = grids[1]
    return first, second
# Debug dump on the first epoch of a test run: show the targets and what the
# model encodes for the current batch.
# NOTE(review): the model is switched to eval mode here and not switched back
# within this fragment - confirm a later model.train() exists.
if (epoch == 0) and (testing == True):
    print('TARGETS: ', targets)
    model.train(False)
    encodings = model.test_encode(batch)
    print('ENCODINGS: ', encodings)

# =============================================================================
# Set up the auxiliary tensors for the training run
# =============================================================================
batch_size = 64

# 64 coordinates centred on zero: -31.5, -30.5, ..., 31.5
l = torch.arange(0 - 63 / 2., (63 / 2.) + 1)
# Note the unpacking order: the first meshgrid output is bound to yyy.
yyy, xxx, zzz = torch.meshgrid(l, l, l)

# One copy of each coordinate volume per batch element.
xxx, yyy, zzz = xxx.repeat(batch_size, 1, 1, 1), yyy.repeat(batch_size, 1, 1, 1), zzz.repeat(batch_size, 1, 1, 1)
xxx = xxx.to(device).to(torch.float)
yyy = yyy.to(device).to(torch.float)
zzz = zzz.to(device).to(torch.float)

# Scratch mask with the same shape as the coordinate volumes, plus a scalar
# threshold; how they are used is outside this fragment.
mask = torch.zeros((xxx.shape)).to(device).to(torch.float)
thresh = torch.tensor([3]).to(device).to(torch.float)

# =============================================================================
# Instantiate the model
# =============================================================================
def proj_cost(settings, ref_feature, src_feature, level, ref_in, src_in,
              ref_ex, src_ex, depth_hypos):
    """Build the variance cost volume for refined depth hypotheses.

    Every source feature map is warped into the reference view for each
    per-pixel depth hypothesis. The cost is the variance across the
    reference and all warped source features: ``E[f^2] - E[f]^2``.

    Args:
        settings: object exposing ``nsrc`` (number of source views) and
            ``mode`` (``"test"`` releases cached CUDA memory at the end).
        ref_feature: reference features, indexed as (B, C, H, W).
        src_feature: ``src_feature[src][level]`` is the feature map of
            source view ``src`` at pyramid level ``level``.
        level: pyramid level to read from ``src_feature``.
        ref_in: reference intrinsics, used as (B, 3, 3).
        src_in: source intrinsics, indexed ``[:, src, :, :]``.
        ref_ex: reference extrinsics, rows ``0:3`` are used.
        src_ex: source extrinsics, indexed ``[:, src, 0:3, :]``.
        depth_hypos: depth hypotheses, viewed as (B, Ndepth, H * W).

    Returns:
        Cost volume viewed as (B, C, Ndepth, H, W).
    """
    batch, channels = ref_feature.shape[0], ref_feature.shape[1]
    num_depth = depth_hypos.shape[1]
    height, width = ref_feature.shape[2], ref_feature.shape[3]

    volume_sum = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
    # Must be out-of-place: ``pow_`` would square ``volume_sum`` itself, so
    # the running sum would start from ref^2 instead of ref and the result
    # would no longer be a variance.
    volume_sq_sum = volume_sum.pow(2)

    with torch.no_grad():
        # Everything that does not depend on the source view is computed
        # once. ``new_tensor`` keeps the homogeneous row on the inputs'
        # device and dtype instead of hard-coding CUDA.
        last = ref_in.new_tensor([[[0, 0, 0, 1.0]]]).repeat(
            len(src_in), 1, 1)
        ref_proj = torch.matmul(ref_in, ref_ex[:, 0:3, :])
        ref_proj = torch.cat((ref_proj, last), 1)
        ref_proj_inv = torch.inverse(ref_proj)

        y, x = torch.meshgrid([
            torch.arange(0, height, dtype=torch.float32,
                         device=ref_feature.device),
            torch.arange(0, width, dtype=torch.float32,
                         device=ref_feature.device)
        ])
        y, x = y.contiguous(), x.contiguous()
        y, x = y.view(height * width), x.view(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]

    for src in range(settings.nsrc):
        with torch.no_grad():
            src_proj = torch.matmul(src_in[:, src, :, :],
                                    src_ex[:, src, 0:3, :])
            src_proj = torch.cat((src_proj, last), 1)

            # Reference pixel -> source pixel transform.
            proj = torch.matmul(src_proj, ref_proj_inv)
            rot = proj[:, :3, :3]
            trans = proj[:, :3, 3:4]

            rot_xyz = torch.matmul(rot, xyz)
            rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(
                1, 1, num_depth, 1) * depth_hypos.view(
                    batch, 1, num_depth,
                    height * width)  # [B, 3, Ndepth, H*W]
            proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)
            proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]
            # Pixel coordinates -> grid_sample's [-1, 1] range.
            proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
            proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
            grid = torch.stack((proj_x_normalized, proj_y_normalized),
                               dim=3)

        warped_src_fea = F.grid_sample(
            src_feature[src][level],
            grid.view(batch, num_depth * height, width, 2),
            mode='bilinear', padding_mode='zeros')
        warped_src_fea = warped_src_fea.view(batch, channels, num_depth,
                                             height, width)

        volume_sum = volume_sum + warped_src_fea
        volume_sq_sum = volume_sq_sum + warped_src_fea.pow_(2)

    # In-place ops keep peak memory down; the result aliases volume_sq_sum.
    cost_volume = volume_sq_sum.div_(settings.nsrc + 1).sub_(
        volume_sum.div_(settings.nsrc + 1).pow_(2))

    if settings.mode == "test":
        del volume_sum
        del volume_sq_sum
        torch.cuda.empty_cache()

    return cost_volume
def calDepthHypo(netArgs, ref_depths, ref_intrinsics, src_intrinsics,
                 ref_extrinsics, src_extrinsics, depth_min, depth_max, level):
    """Calculate depth hypothesis maps for the refinement steps.

    ``2 * d`` hypotheses (``d = 4``) are placed around every pixel of
    ``ref_depths`` at offsets ``-d, ..., d - 1`` times a depth interval.

    * In ``"train"`` mode the interval is the hard-coded constant 6.8085.
    * Otherwise the interval is estimated per batch element: each pixel is
      back-projected at its depth and at depth + 1, both points are
      projected into the first source view, and the depth change that moves
      the projection by ``pixel_interval`` pixels along that direction is
      solved for. The mean absolute value over all pixels is used for the
      whole image.

    ``depth_min``, ``depth_max`` and ``level`` are not read by this function.

    NOTE(review): both paths call ``.cuda()`` unconditionally, so the inputs
    are presumably expected to live on a CUDA device; confirm.

    Returns:
        Tensor built as ``ref_depths.unsqueeze(1).repeat(1, 2 * d, 1, 1)``
        plus the per-hypothesis offsets.
    """
    nhypothesis_init = 48  # not used below
    d = 4
    pixel_interval = 1

    nBatch = ref_depths.shape[0]
    height = ref_depths.shape[1]
    width = ref_depths.shape[2]

    if netArgs.mode == "train":
        depth_interval = torch.tensor([6.8085] * nBatch).cuda()  # Hard code the interval for training on DTU with 1 level of refinement.
        depth_hypos = ref_depths.unsqueeze(1).repeat(1, d * 2, 1, 1)
        for depth_level in range(-d, d):
            depth_hypos[:, depth_level + d, :, :] += (depth_level) * depth_interval[0]

        return depth_hypos

    with torch.no_grad():
        ref_depths = ref_depths
        ref_intrinsics = ref_intrinsics
        # NOTE(review): squeeze(1) only changes the shape when dim 1 has
        # size 1, yet the tensors are indexed with [batch][0] below; confirm
        # the expected layout.
        src_intrinsics = src_intrinsics.squeeze(1)
        ref_extrinsics = ref_extrinsics
        src_extrinsics = src_extrinsics.squeeze(1)

        interval_maps = []
        depth_hypos = ref_depths.unsqueeze(1).repeat(1, d * 2, 1, 1)
        for batch in range(nBatch):
            # Pixel grid enumerated with x as the slow axis (width first).
            xx, yy = torch.meshgrid(
                [torch.arange(0, width), torch.arange(0, height)])

            xxx = xx.reshape([-1]).cuda().float()
            yyy = yy.reshape([-1]).cuda().float()

            # Homogeneous pixel coordinates, one column per pixel.
            X = torch.stack([xxx, yyy, torch.ones_like(xxx)], dim=0)

            D1 = torch.transpose(ref_depths[batch, :, :], 0, 1).reshape(
                [-1]
            )  # Transpose before reshape to produce identical results to numpy and matlab version.
            # D2 = D1+1

            # Pixels scaled by the current depth (X1) and by depth + 1 (X2).
            X1 = X * D1
            D1 = D1 + 1
            X2 = X * D1
            ray1 = torch.matmul(torch.inverse(ref_intrinsics[batch]), X1)
            ray2 = torch.matmul(torch.inverse(ref_intrinsics[batch]), X2)

            # Homogenise, then undo the reference extrinsics.
            X1 = torch.cat([ray1, X[-1, :].unsqueeze(0)], dim=0)
            X1 = torch.matmul(torch.inverse(ref_extrinsics[batch]), X1)
            X2 = torch.cat([ray2, X[-1, :].unsqueeze(0)], dim=0)
            X2 = torch.matmul(torch.inverse(ref_extrinsics[batch]), X2)

            # Apply the first source view's extrinsics.
            X1 = torch.matmul(src_extrinsics[batch][0], X1)
            X2 = torch.matmul(src_extrinsics[batch][0], X2)

            # Project both points with the source intrinsics; keep the
            # source depth of the first one.
            X1 = X1[:3]
            X1 = torch.matmul(src_intrinsics[batch][0], X1)
            X1_d = X1[2].clone()
            X1 /= X1_d

            X2 = X2[:3]
            X2 = torch.matmul(src_intrinsics[batch][0], X2)
            # X2_d = X2[2].clone()
            X2 = X2 / X2[2]

            # Slope of the line through the two projections.
            k = (X2[1] - X1[1]) / (X2[0] - X1[0])
            # b = X1[1]-k*X1[0]
            del X2

            # Step pixel_interval pixels along that line in the source image.
            theta = torch.atan(k)
            X3 = X1 + torch.stack([
                torch.cos(theta) * pixel_interval,
                torch.sin(theta) * pixel_interval,
                torch.zeros_like(X1[2, :])
            ], dim=0)

            # Rotation part of the source-to-reference pixel mapping.
            A = torch.matmul(ref_intrinsics[batch],
                             ref_extrinsics[batch][:3, :3])
            tmp = torch.matmul(src_intrinsics[batch][0],
                               src_extrinsics[batch][0, :3, :3])
            A = torch.matmul(A, torch.inverse(tmp))

            tmp1 = X1_d * torch.matmul(A, X1)
            tmp2 = torch.matmul(A, X3)

            # One 2x2 linear system per pixel; the first unknown is read
            # out as the depth change.
            M1 = torch.cat(
                [X.t().unsqueeze(2), tmp2.t().unsqueeze(2)],
                axis=2)[:, 1:, :]
            M2 = tmp1.t()[:, 1:]
            ans = torch.matmul(torch.inverse(M1), M2.unsqueeze(2))
            delta_d = ans[:, 0, 0]

            # One scalar interval per image, broadcast to an (H, W) map.
            interval_maps = torch.abs(delta_d).mean().repeat(
                ref_depths.shape[2], ref_depths.shape[1]).t()

            for depth_level in range(-d, d):
                depth_hypos[batch, depth_level + d, :, :] += depth_level * interval_maps

        # print("Calculated:")
        # print(interval_maps[0,0])

        # pdb.set_trace()

    return depth_hypos  # Return the depth hypothesis map from statistical interval setting.
def grid(W):
    """Return the ``(W * W, 2)`` list of ``(x, y)`` points of a unit grid.

    Coordinates are ``0/W, 1/W, ..., (W - 1)/W`` cast to the module-level
    ``dtype``; rows are enumerated with ``x`` varying fastest.
    """
    axis = torch.arange(0., W).type(dtype) / W
    rows, cols = torch.meshgrid([axis, axis])
    points = torch.stack((cols, rows), dim=2)
    return points.view(-1, 2)
def _train_on_epoch(self, model, optimizer):
    """Train ``model`` for one epoch and return the epoch-averaged losses.

    Each batch is unpacked with ``utils.unpack_batch_fixed``, pushed through
    the model, and the loss ``loss_mask + cfg.lambdaDepth * loss_XYZ`` is
    back-propagated. If ``cfg.trueWD`` is set, decoupled ("true") weight decay
    is applied to the parameters right before ``optimizer.step()``.

    Args:
        model: network returning ``(XYZ, maskLogit)`` for a batch of images.
        optimizer: optimizer whose ``param_groups`` hold the model parameters.

    Returns:
        dict with ``epoch_loss_XYZ``, ``epoch_loss_mask`` and ``epoch_loss``,
        each a sample-weighted sum divided by ``len(data_loader.dataset)``.
    """
    model.train()

    data_loader = self.data_loaders[0]
    running_loss_XYZ = 0.0
    running_loss_mask = 0.0
    running_loss = 0.0

    # ------ define ground truth------
    # The XY ground truth depends only on the config, so build it once
    # instead of once per batch.
    XGT, YGT = torch.meshgrid([
        torch.arange(self.cfg.outH),  # [H,W]
        torch.arange(self.cfg.outW)])  # [H,W]
    XGT, YGT = XGT.float(), YGT.float()
    XYGT = torch.cat([
        XGT.repeat([self.cfg.outViewN, 1, 1]),
        YGT.repeat([self.cfg.outViewN, 1, 1])], dim=0)  # [2V,H,W]
    XYGT = XYGT.unsqueeze(dim=0).to(self.cfg.device)  # [1,2V,H,W]

    # The global iteration counter lives on self and is advanced by the loop.
    for self.iteration, batch in enumerate(data_loader, self.iteration):
        input_images, depthGT, maskGT = utils.unpack_batch_fixed(
            batch, self.cfg.device)

        with torch.set_grad_enabled(True):
            optimizer.zero_grad()

            XYZ, maskLogit = model(input_images)
            XY = XYZ[:, :self.cfg.outViewN * 2, :, :]
            depth = XYZ[:, self.cfg.outViewN * 2:self.cfg.outViewN * 3, :, :]
            mask = (maskLogit > 0).byte()

            # ------ Compute loss ------
            loss_XYZ = self.l1(XY, XYGT)
            # Depth is only penalised where the predicted mask is positive.
            loss_XYZ += self.l1(depth.masked_select(mask),
                                depthGT.masked_select(mask))
            loss_mask = self.sigmoid_bce(maskLogit, maskGT)
            loss = loss_mask + self.cfg.lambdaDepth * loss_XYZ

            # ------ Update weights ------
            loss.backward()

            # True Weight decay: p <- p - trueWD * lr * p
            if self.cfg.trueWD is not None:
                for group in optimizer.param_groups:
                    decay = -self.cfg.trueWD * group['lr']
                    for param in group['params']:
                        # Keyword `alpha` replaces the deprecated
                        # add_(Number, Tensor) positional overload.
                        param.data.add_(param.data, alpha=decay)
            optimizer.step()

            if self.on_after_batch is not None:
                # NOTE(review): this is a substring test, so "" or "cyc" also
                # select the per-iteration callback - confirm `==` was not
                # intended.
                if self.cfg.lrSched.lower() in "cyclical":
                    self.on_after_batch(self.iteration)
                else:
                    self.on_after_batch(self.epoch)

        batch_size = input_images.size(0)
        running_loss_XYZ += loss_XYZ.item() * batch_size
        running_loss_mask += loss_mask.item() * batch_size
        running_loss += loss.item() * batch_size

    num_samples = len(data_loader.dataset)
    epoch_loss_XYZ = running_loss_XYZ / num_samples
    epoch_loss_mask = running_loss_mask / num_samples
    epoch_loss = running_loss / num_samples

    print(f"\tTrain loss: {epoch_loss}")

    return {"epoch_loss_XYZ": epoch_loss_XYZ,
            "epoch_loss_mask": epoch_loss_mask,
            "epoch_loss": epoch_loss, }
def forward(self, x, y, loss_mask=None):
    """Sum the squared errors between predicted and prior centroid offsets.

    ``x`` holds the outputs/predictions and ``y`` the ground truths; ``y`` and
    ``loss_mask`` are not read by this method. For every prior relation in
    ``self.GT_sr`` the offset between the source and target centroids
    (computed by ``self.compute_centroids(x)``) is compared with the stored
    ground-truth offset, normalised by an image dimension, and its mean
    squared value is accumulated.

    Side effects: sets ``self.image_dimensions`` and ``self.coords``.

    Returns:
        The accumulated error (the int ``0`` when there are no relations).
    """
    self.image_dimensions = x.shape
    # The spatial extent is only known once x is seen, hence not in __init__.
    _, _, h, w, d = self.image_dimensions
    self.coords = torch.meshgrid(
        [torch.arange(h), torch.arange(w), torch.arange(d)])

    sour, targ, dy_gt, dx_gt, dd_gt = self.GT_sr
    self.compute_centroids(x)
    # Colour relations (self.compute_colors) are not implemented.

    def _mean_sq(diff):
        # Non-finite entries are replaced by 0 before squaring.
        return torch.mean(torch.square(weird_to_num(diff, replace=0.0)))

    # NOTE(review): y is normalised by dimension 3 and x by dimension 2 -
    # confirm this matches the axis convention used by compute_centroids.
    norm_y = self.image_dimensions[3]
    norm_x = self.image_dimensions[2]
    norm_d = self.image_dimensions[4]

    sum_err = 0
    # For each prior relation, compare the predicted offset with the ground truth.
    for rel in range(len(sour)):
        src = int(sour[rel])
        tgt = int(targ[rel])

        offset_y = self.centroids_y[src] - self.centroids_y[tgt]
        offset_x = self.centroids_x[src] - self.centroids_x[tgt]
        offset_d = self.centroids_d[src] - self.centroids_d[tgt]

        err_y = _mean_sq((offset_y - dy_gt[rel]) / norm_y)
        err_x = _mean_sq((offset_x - dx_gt[rel]) / norm_x)
        err_d = _mean_sq((offset_d - dd_gt[rel]) / norm_d)

        sum_err += err_y + err_x + err_d

    return sum_err
def graph_layer(self, encoded_seq, info, word_sec, section, positions):
    """
    Graph Layer -> Construct a document-level graph
    The graph edges hold representations for the connections between the nodes.
    Args:
        encoded_seq: Encoded sequence, shape (sentences, words, dimension)
        info: (Tensor, 5 columns) entity_id, entity_type, start_wid, end_wid, sentence_id
        word_sec: (Tensor) number of words per sentence
        section: (Tensor <B, 3>) #entities/#mentions/#sentences per batch
        positions: distances between nodes (only M-M and S-S)

    Returns: (Tensor) graph, (Tensor) tensor_mapping, (Tensors) indices, (Tensor) node information
    """
    device = self.device

    # Sentence nodes: average over the words of each sentence.
    sentences = torch.mean(encoded_seq, dim=1)

    # Mention and entity nodes; the mask marks the non-padding word slots.
    word_positions = torch.arange(word_sec.max()).unsqueeze(0).repeat(
        sentences.size(0), 1).to(device)
    remove_pad = word_positions < word_sec.unsqueeze(1)

    mentions = self.merge_tokens(info, encoded_seq, remove_pad)
    entities = self.merge_mentions(info, mentions)

    # All nodes in order: entities - mentions - sentences.
    nodes = torch.cat((entities, mentions, sentences), dim=0)
    # info/node: node type | semantic type | sentence ID
    nodes_info = self.node_info(section, info)

    if self.types:
        # Append the node-type embedding.
        type_vecs = self.type_embed(nodes_info[:, 0])
        nodes = torch.cat((nodes, type_vecs), dim=1)

    # Re-order the nodes per document (batch) and pad.
    nodes = self.split_n_pad(self.rearrange_nodes(nodes, section),
                             section, pad=0)
    nodes_info = self.split_n_pad(self.rearrange_nodes(nodes_info, section),
                                  section, pad=-1)

    # Initial edges: concatenation of the two node representations.
    r_idx, c_idx = torch.meshgrid(
        torch.arange(nodes.size(1)).to(device),
        torch.arange(nodes.size(1)).to(device))
    graph = torch.cat((nodes[:, r_idx], nodes[:, c_idx]), dim=3)

    # Node-type indicators of both edge ends, and the resulting pair masks.
    node_types = nodes_info[..., 0]
    r_id = node_types[:, r_idx]
    c_id = node_types[:, c_idx]
    pid = self.pair_ids(r_id, c_id)

    def _fill(pair, edge_input, current):
        # Write the reduced edges of type `pair`, keep everything else.
        return torch.where(pid[pair].unsqueeze(-1),
                           self.reduce[pair](edge_input), current)

    # Linear reduction layers, one per edge type.
    reduced_graph = torch.zeros(
        graph.size()[:-1] + (self.out_dim, )).to(device)
    for pair in ('MS', 'ME', 'ES'):
        reduced_graph = _fill(pair, graph, reduced_graph)

    # Sentence-sentence edges optionally carry the distance embedding.
    if self.dist:
        dist_vec = self.dist_embed(positions)  # distances
        ss_input = torch.cat((graph, dist_vec), dim=3)
    else:
        ss_input = graph
    reduced_graph = _fill('SS', ss_input, reduced_graph)

    # Mention-mention edges optionally carry distance and/or context.
    mm_parts = [graph]
    if self.dist:
        mm_parts.append(dist_vec)
    if self.context:
        m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info,
                                word_sec)
        m_cntx = self.prepare_mention_context(m_cntx, section, r_idx, c_idx,
                                              encoded_seq[info[:, 4]], pid,
                                              nodes_info)
        mm_parts.append(m_cntx)
    mm_input = graph if len(mm_parts) == 1 else torch.cat(mm_parts, dim=3)
    reduced_graph = _fill('MM', mm_input, reduced_graph)

    if self.ee:
        reduced_graph = _fill('EE', graph, reduced_graph)

    mask = self.get_nodes_mask(section.sum(dim=1))
    return reduced_graph, (r_idx, c_idx), nodes_info, mask
def generate_region_meshgrid(num_regions, region_size, region_offsets,
                             device="cuda"):
    """Return the row/column coordinates of every region of a regular grid.

    Args:
        num_regions: ``(rows, cols)`` number of regions along each axis.
        region_size: ``(height, width)`` stride between consecutive regions.
        region_offsets: ``(row_offset, col_offset)`` added to every coordinate.
        device: device the result is placed on. Defaults to ``"cuda"``, which
            matches the previous hard-coded behaviour.

    Returns:
        ``(hh, ww)``: two float tensors of shape ``num_regions`` holding
        ``index * region_size + region_offsets`` along the row and column
        axis respectively.
    """
    hh, ww = torch.meshgrid(torch.arange(0, num_regions[0]),
                            torch.arange(0, num_regions[1]))
    hh = (hh * region_size[0] + region_offsets[0]).to(device).float()
    ww = (ww * region_size[1] + region_offsets[1]).to(device).float()
    return hh, ww
def eval(args, unet, imnet, eval_loader, epoch, global_step, device, logger,
         writer, optimizer, pde_layer):
    """Eval function. Used for evaluating entire slices and comparing to GT.

    Takes the first batch of ``eval_loader``, encodes the low-res grid with
    ``unet``, queries ``imnet`` through ``pde_layer`` on a regular
    (t, z, x) grid in chunks of ``args.pseudo_batch_size`` points, and logs
    the predicted fields, the PDE residues and the ground-truth slices as
    images to ``writer``.

    The temporal axis is sub-sampled with stride ``nt // 8``, clamped to at
    least 1 so that inputs with fewer than 8 time steps do not produce a zero
    slice step.

    ``epoch``, ``logger`` and ``optimizer`` are not read by this function.

    Raises:
        ValueError: if ``eval_loader`` yields no batch.
    """
    unet.eval()
    imnet.eval()
    phys_channels = ["p", "b", "u", "w"]
    phys2id = dict(zip(phys_channels, range(len(phys_channels))))
    xmin = torch.zeros(3, dtype=torch.float32).to(device)
    xmax = torch.ones(3, dtype=torch.float32).to(device)

    # only need the first batch
    data_tensors = next(iter(eval_loader), None)
    if data_tensors is None:
        raise ValueError("eval_loader did not yield any batch")

    # send tensors to device
    data_tensors = [t.to(device) for t in data_tensors]
    hres_grid, lres_grid, _, _ = data_tensors
    latent_grid = unet(lres_grid)  # [batch, C, T, Z, X]
    nb, _, nt, nz, nx = hres_grid.shape

    # permute such that C is the last channel for local implicit grid query
    latent_grid = latent_grid.permute(0, 2, 3, 4, 1)  # [batch, T, Z, X, C]

    # define lambda function for pde_layer
    fwd_fn = lambda points: query_local_implicit_grid(
        imnet, latent_grid, points, xmin, xmax)

    # update pde layer and compute predicted values + pde residues
    pde_layer.update_forward_method(fwd_fn)

    # layout query points for the desired slices
    eps = 1e-6
    t_stride = max(1, nt // 8)  # a zero step would make the slices below fail
    t_seq = torch.linspace(eps, 1 - eps, nt)[::t_stride]  # temporal sequences
    z_seq = torch.linspace(eps, 1 - eps, nz)  # z sequences
    x_seq = torch.linspace(eps, 1 - eps, nx)  # x sequences

    query_coord = torch.stack(torch.meshgrid(t_seq, z_seq, x_seq),
                              dim=-1)  # [nt, nz, nx, 3]
    query_coord = query_coord.reshape([-1, 3]).to(device)  # [nt*nz*nx, 3]
    n_query = query_coord.shape[0]

    res_dict = defaultdict(list)

    # Query in chunks to bound the memory used per pde_layer call.
    for sid in range(0, n_query, args.pseudo_batch_size):
        eid = min(sid + args.pseudo_batch_size, n_query)
        query_coord_batch = query_coord[sid:eid]
        query_coord_batch = query_coord_batch[None].expand(
            nb, eid - sid, 3)  # [nb, eid-sid, 3]

        pred_value, residue_dict = pde_layer(query_coord_batch,
                                             return_residue=True)
        pred_value = pred_value.detach()
        for key in residue_dict.keys():
            residue_dict[key] = residue_dict[key].detach()
        for name in phys_channels:
            res_dict[name].append(pred_value[..., phys2id[name]])  # [b, pb]
        for name, val in residue_dict.items():
            res_dict[name].append(val[..., 0])  # [b, pb]

    # Stitch the chunks back into [nb, nt, nz, nx] volumes.
    grid_shape = [nb, len(t_seq), len(z_seq), len(x_seq)]
    for key in res_dict.keys():
        res_dict[key] = torch.cat(res_dict[key], dim=1).reshape(grid_shape)

    # log the imgs sample-by-sample
    for samp_id in range(nb):
        for key in res_dict.keys():
            field = res_dict[key][samp_id]  # [nt, nz, nx]
            # add predicted slices
            images = utils.batch_colorize_scalar_tensors(
                field)  # [nt, nz, nx, 3]
            writer.add_images('sample_{}/{}/predicted'.format(samp_id, key),
                              images,
                              dataformats='NHWC',
                              global_step=int(global_step))
            # add ground truth slices (only for phys channels)
            if key in phys_channels:
                gt_fields = hres_grid[samp_id, phys2id[key],
                                      ::t_stride]  # [nt, nz, nx]
                gt_images = utils.batch_colorize_scalar_tensors(
                    gt_fields)  # [nt, nz, nx, 3]
                writer.add_images('sample_{}/{}/ground_truth'.format(
                    samp_id, key),
                                  gt_images,
                                  dataformats='NHWC',
                                  global_step=int(global_step))
def render_singletask_gp(
        ax: [plt.Axes, Axes3D, Sequence[plt.Axes]],
        data_x: to.Tensor,
        data_y: to.Tensor,
        idcs_sel: list,
        data_x_min: to.Tensor = None,
        data_x_max: to.Tensor = None,
        x_label: str = '',
        y_label: str = '',
        z_label: str = '',
        min_gp_obsnoise: float = None,
        resolution: int = 201,
        num_stds: int = 2,
        alpha: float = 0.3,
        color: chr = None,
        curve_label: str = 'mean',
        heatmap_cmap: colors.Colormap = None,
        show_legend_posterior: bool = True,
        show_legend_std: bool = False,
        show_legend_data: bool = True,
        legend_data_cmap: colors.Colormap = None,
        colorbar_label: str = None,
        title: str = None,
        render3D: bool = True,
) -> plt.Figure:
    """
    Fit the GP posterior to the input data and plot the mean and std as well as the data points.
    There are 3 options: a 1D plot (inferred from the data dimensions), a 3D surface plot of 2-dim inputs
    (`render3D=True`), or a pair of 2D heat maps of 2-dim inputs (`render3D=False`).

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param ax: axis of the figure to plot on; for the 2D heat map plot provide 4 axes (2 heat maps and 2 color bars)
    :param data_x: data to plot on the x-axis
    :param data_y: data to process and plot on the y-axis
    :param idcs_sel: selected indices of the input data
    :param data_x_min: explicit minimum value for the evaluation grid, by default this value is extracted from `data_x`
    :param data_x_max: explicit maximum value for the evaluation grid, by default this value is extracted from `data_x`
    :param x_label: label for x-axis
    :param y_label: label for y-axis
    :param z_label: label for z-axis (3D plot only)
    :param min_gp_obsnoise: set a minimal noise value (normalized) for the GP, if `None` the GP has no measurement noise
    :param resolution: number of samples for the input (corresponds to x-axis resolution of the plot)
    :param num_stds: number of standard deviations to plot around the mean
    :param alpha: transparency (alpha-value) for the std area
    :param color: color (e.g. 'k' for black), `None` invokes the default behavior
    :param curve_label: label for the mean curve (1D plot only); currently not forwarded to the plot call
    :param heatmap_cmap: color map forwarded to `render_heatmap()` (2D plot only), `None` to use Pyrado's default
    :param show_legend_posterior: flag if the legend entry for the posterior should be printed (affects mean and std)
    :param show_legend_std: flag if a legend entry for the std area should be printed; currently not evaluated
    :param show_legend_data: flag if a legend entry for the individual data points should be printed
    :param legend_data_cmap: color map for the sampled points, default is 'binary'
    :param colorbar_label: label for the color bar (2D plot only)
    :param title: title displayed above the figure, set to `None` to suppress the title; currently not evaluated
    :param render3D: use 3D rendering if possible
    :return: handle to the resulting figure
    """
    if data_x.ndim != 2:
        raise pyrado.ShapeErr(
            msg=
            "The GP's input data needs to be of shape num_samples x dim_input!"
        )
    data_x = data_x[:, idcs_sel]  # forget the rest
    dim_x = data_x.shape[1]  # samples are along axis 0

    if data_y.ndim != 2:
        raise pyrado.ShapeErr(given=data_y,
                              expected_match=to.Size([data_x.shape[0], 1]))

    if legend_data_cmap is None:
        legend_data_cmap = plt.get_cmap('binary')

    # Project to normalized input and standardized output
    if data_x_min is None or data_x_max is None:
        data_x_min, data_x_max = to.min(data_x, dim=0)[0], to.max(data_x,
                                                                   dim=0)[0]
    data_y_mean, data_y_std = to.mean(data_y, dim=0), to.std(data_y, dim=0)
    data_x = (data_x - data_x_min) / (data_x_max - data_x_min)
    data_y = (data_y - data_y_mean) / data_y_std

    # Create and fit the GP model
    gp = SingleTaskGP(data_x, data_y)
    if min_gp_obsnoise is not None:
        gp.likelihood.noise_covar.register_constraint(
            'raw_noise', GreaterThan(min_gp_obsnoise))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_model(mll)
    print_cbt('Fitted the SingleTaskGP.', 'g')

    # Maximize the posterior mean over the normalized input cube [0, 1]^dim_x
    argmax_pmean_norm, argmax_pmean_val_stdzed = optimize_acqf(
        acq_function=PosteriorMean(gp),
        bounds=to.stack([to.zeros(dim_x), to.ones(dim_x)]),
        q=1,
        num_restarts=500,
        raw_samples=1000)
    # Project back
    argmax_posterior = argmax_pmean_norm * (data_x_max -
                                            data_x_min) + data_x_min
    argmax_pmean_val = argmax_pmean_val_stdzed * data_y_std + data_y_mean
    print_cbt(
        f'Converged to argmax of the posterior mean: {argmax_posterior.numpy()}',
        'g')

    mll.eval()
    gp.eval()

    # Indices of the query points, used to color them in order of appearance.
    # The builtin `int` replaces the `np.int` alias that was removed from NumPy.
    query_idcs = np.arange(data_x.shape[0], dtype=int)

    if dim_x == 1:
        # Evaluation grid
        x_grid = np.linspace(min(data_x),
                             max(data_x),
                             resolution,
                             endpoint=True).flatten()
        x_grid = to.from_numpy(x_grid)

        # Mean and standard deviation of the surrogate model
        posterior = gp.posterior(x_grid)
        mean = posterior.mean.detach().flatten()
        std = to.sqrt(posterior.variance.detach()).flatten()

        # Project back from normalized input and standardized output
        x_grid = x_grid * (data_x_max - data_x_min) + data_x_min
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean = mean * data_y_std + data_y_mean
        std *= data_y_std  # double-checked with posterior.mvn.confidence_region()

        # Plot the curve; the std band goes on the given axis (not on pyplot's
        # current axis) so that it ends up next to the mean curve
        ax.fill_between(x_grid.numpy(),
                        mean.numpy() - num_stds * std.numpy(),
                        mean.numpy() + num_stds * std.numpy(),
                        alpha=alpha,
                        color=color)
        ax.plot(x_grid.numpy(), mean.numpy(), color=color)

        # Plot the queried data points
        scat_plot = ax.scatter(data_x.numpy().flatten(),
                               data_y.numpy().flatten(),
                               marker='o',
                               c=query_idcs,
                               cmap=legend_data_cmap)

        if show_legend_data:
            scat_legend = ax.legend(
                *scat_plot.legend_elements(fmt='{x:.0f}'),  # integer formatter
                bbox_to_anchor=(0., 1.1, 1., -0.1),
                title='query points',
                ncol=data_x.shape[0],
                loc='upper center',
                mode='expand',
                borderaxespad=0.,
                handletextpad=-0.5)
            ax.add_artist(scat_legend)

        # Plot the argmax of the posterior mean
        ax.axvline(argmax_posterior.item(),
                   c='darkorange',
                   lw=1.5,
                   label='argmax')

        if show_legend_posterior:
            ax.add_artist(ax.legend(loc='lower right'))

    elif dim_x == 2:
        # Create mesh grid matrices from x and y vectors (in normalized space)
        x0_grid = to.linspace(0, 1, resolution)
        x1_grid = to.linspace(0, 1, resolution)
        x0_mesh, x1_mesh = to.meshgrid([x0_grid, x1_grid])
        # transpose not necessary but makes identical mesh as np.meshgrid
        x0_mesh, x1_mesh = x0_mesh.t(), x1_mesh.t()

        # Mean and standard deviation of the surrogate model
        x_test = to.stack([
            x0_mesh.reshape(resolution**2, 1),
            x1_mesh.reshape(resolution**2, 1)
        ], -1).squeeze(1)
        posterior = gp.posterior(
            x_test)  # identical to gp.likelihood(gp(x_test))
        mean = posterior.mean.detach().reshape(resolution, resolution)
        std = to.sqrt(posterior.variance.detach()).reshape(
            resolution, resolution)

        # Project back from normalized input and standardized output
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean_raw = mean * data_y_std + data_y_mean
        std_raw = std * data_y_std

        if render3D:
            # Project back from normalized input and standardized output (custom for 3D)
            x0_mesh = x0_mesh * (data_x_max[0] -
                                 data_x_min[0]) + data_x_min[0]
            x1_mesh = x1_mesh * (data_x_max[1] -
                                 data_x_min[1]) + data_x_min[1]
            lower = mean_raw - num_stds * std_raw
            upper = mean_raw + num_stds * std_raw

            # Plot a 2D surface in 3D
            ax.plot_surface(x0_mesh.numpy(), x1_mesh.numpy(),
                            mean_raw.numpy())
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            lower.numpy(),
                            color='r',
                            alpha=alpha)
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            upper.numpy(),
                            color='r',
                            alpha=alpha)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_zlabel(z_label)

            # Plot the queried data points
            scat_plot = ax.scatter(data_x[:, 0].numpy(),
                                   data_x[:, 1].numpy(),
                                   data_y.numpy(),
                                   marker='o',
                                   c=query_idcs,
                                   cmap=legend_data_cmap)

            if show_legend_data:
                scat_legend = ax.legend(
                    *scat_plot.legend_elements(
                        fmt='{x:.0f}'),  # integer formatter
                    bbox_to_anchor=(0.05, 1.1, 0.95, -0.1),
                    loc='upper center',
                    ncol=data_x.shape[0],
                    mode='expand',
                    borderaxespad=0.,
                    handletextpad=-0.5)
                ax.add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            x, y = argmax_posterior[0, 0], argmax_posterior[0, 1]
            ax.scatter(x,
                       y,
                       argmax_pmean_val,
                       c='darkorange',
                       marker='*',
                       s=60)

        else:
            if not len(ax) == 4:
                raise pyrado.ShapeErr(
                    msg='Provide 4 axes! 2 heat maps and 2 color bars.')

            # Project back normalized input and standardized output (custom for 2D)
            x0_grid_raw = x0_grid * (data_x_max[0] -
                                     data_x_min[0]) + data_x_min[0]
            x1_grid_raw = x1_grid * (data_x_max[1] -
                                     data_x_min[1]) + data_x_min[1]

            # Plot a 2D image
            df_mean = pd.DataFrame(mean_raw.numpy(),
                                   columns=x0_grid_raw.numpy(),
                                   index=x1_grid_raw.numpy())
            render_heatmap(df_mean,
                           ax_hm=ax[0],
                           ax_cb=ax[1],
                           x_label=x_label,
                           y_label=y_label,
                           annotate=False,
                           fig_canvas_title='Returns',
                           tick_label_prec=2,
                           add_sep_colorbar=True,
                           cmap=heatmap_cmap,
                           colorbar_label=colorbar_label,
                           num_major_ticks_hm=3,
                           num_major_ticks_cb=2,
                           colorbar_orientation='horizontal')

            df_std = pd.DataFrame(std_raw.numpy(),
                                  columns=x0_grid_raw.numpy(),
                                  index=x1_grid_raw.numpy())
            render_heatmap(
                df_std,
                ax_hm=ax[2],
                ax_cb=ax[3],
                x_label=x_label,
                y_label=y_label,
                annotate=False,
                fig_canvas_title='Standard Deviations',
                tick_label_prec=2,
                add_sep_colorbar=True,
                cmap=heatmap_cmap,
                colorbar_label=colorbar_label,
                num_major_ticks_hm=3,
                num_major_ticks_cb=2,
                colorbar_orientation='horizontal',
                norm=colors.Normalize())  # explicitly instantiate a new norm

            # Plot the queried data points on both heat maps
            for i in [0, 2]:
                scat_plot = ax[i].scatter(data_x[:, 0].numpy(),
                                          data_x[:, 1].numpy(),
                                          marker='o',
                                          s=15,
                                          c=query_idcs,
                                          cmap=legend_data_cmap)

                if show_legend_data:
                    scat_legend = ax[i].legend(
                        *scat_plot.legend_elements(
                            fmt='{x:.0f}'),  # integer formatter
                        bbox_to_anchor=(0., 1.1, 1., 0.05),
                        loc='upper center',
                        ncol=data_x.shape[0],
                        mode='expand',
                        borderaxespad=0.,
                        handletextpad=-0.5)
                    ax[i].add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            ax[0].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            ax[2].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue

    else:
        raise pyrado.ValueErr(msg='Can only plot 1-dim or 2-dim data!')

    return plt.gcf()
def forward(self, feature_pyramid1, feature_pyramid2, actions=None):
    """Run the model.

    Walks the two feature pyramids from the coarsest level down to level 1,
    estimating a flow field per level from the cost volume between
    ``features1`` and the (optionally warped) ``features2``, then refines the
    finest estimate and upsamples it twice.

    Feature maps are treated as channels-first ``[B, C, H, W]``: all
    concatenations happen on ``dim=1``.

    Args:
        feature_pyramid1: sequence of feature maps of the first image.
        feature_pyramid2: sequence of feature maps of the second image.
        actions: optional tensor whose ``shape[1]`` must equal
            ``self._action_channels``; must be ``None`` iff the network is
            configured without action channels.

    Returns:
        List of flow fields, finest resolution first.
    """
    context = None
    flow = None
    flow_up = None
    context_up = None
    flows = []

    # make sure that actions are provided iff network is configured for action use
    if actions is not None:
        assert self._action_channels == actions.shape[1]
    else:
        assert self._action_channels is None

    # Go top down through the levels to the second to last one to estimate flow.
    for level, (features1, features2) in reversed(
            list(enumerate(zip(feature_pyramid1, feature_pyramid2)))[1:]):

        # init flows with zeros for coarsest level if needed
        if self._shared_flow_decoder and flow_up is None:
            # torch.Size has no `as_list()`, and the zero tensors must be
            # channels-first to be concatenated with the features on dim=1.
            batch_size, _, height, width = features1.shape
            flow_up = torch.zeros([batch_size, 2, height,
                                   width]).to(gpu_utils.device)
            if self._num_context_up_channels:
                num_channels = int(self._num_context_up_channels *
                                   self._channel_multiplier)
                context_up = torch.zeros(
                    [batch_size, num_channels, height,
                     width]).to(gpu_utils.device)

        # Warp features2 with upsampled flow from higher level.
        if flow_up is None or not self._use_feature_warp:
            warped2 = features2
        else:
            warp_up = uflow_utils.flow_to_warp(flow_up)
            warped2 = uflow_utils.resample(features2, warp_up)

        # Compute cost volume by comparing features1 and warped features2.
        features1_normalized, warped2_normalized = uflow_utils.normalize_features(
            [features1, warped2],
            normalize=self._normalize_before_cost_volume,
            center=self._normalize_before_cost_volume,
            moments_across_channels=True,
            moments_across_images=True)

        if self._use_cost_volume:
            cost_volume = uflow_utils.compute_cost_volume(
                features1_normalized, warped2_normalized, max_displacement=4)
        else:
            concat_features = torch.cat(
                [features1_normalized, warped2_normalized], dim=1)
            cost_volume = self._cost_volume_surrogate_convs[level](
                concat_features)

        cost_volume = func.leaky_relu(
            cost_volume, negative_slope=self._leaky_relu_alpha)

        if self._shared_flow_decoder:
            # This will ensure to work for arbitrary feature sizes per level.
            conv_1x1 = self._1x1_shared_decoder[level]
            features1 = conv_1x1(features1)

        # Compute context and flow from previous flow, cost volume, and features1.
        if flow_up is None:
            x_in = torch.cat([cost_volume, features1], dim=1)
        else:
            if context_up is None:
                x_in = torch.cat([flow_up, cost_volume, features1], dim=1)
            else:
                x_in = torch.cat(
                    [context_up, flow_up, cost_volume, features1], dim=1)

        if self._action_channels is not None and self._action_channels > 0:
            # convert every entry in actions to a channel filled with this
            # value and attach it to flow input
            B, _, H, W = features1.shape
            action_tensor = actions[:, :, None, None].repeat(1, 1, H, W)
            # additionally append xy position augmentation
            gy, gx = torch.meshgrid(
                [torch.arange(H).float(),
                 torch.arange(W).float()])
            gx = gx.repeat(B, 1, 1, 1).to(gpu_utils.device)
            gy = gy.repeat(B, 1, 1, 1).to(gpu_utils.device)
            x_in = torch.cat([x_in, action_tensor, gx, gy], dim=1)

        # Use dense-net connections.
        x_out = None
        if self._shared_flow_decoder:
            # reuse the same flow decoder on all levels
            flow_layers = self._flow_layers
        else:
            flow_layers = self._flow_layers[level]
        for layer in flow_layers[:-1]:
            x_out = layer(x_in)
            x_in = torch.cat([x_in, x_out], dim=1)
        context = x_out

        flow = flow_layers[-1](context)

        # dropout full layer
        if self.training and self._drop_out_rate:
            maybe_dropout = (torch.rand([]) > self._drop_out_rate).type(
                torch.get_default_dtype())
            # note that operation must not be inplace, otherwise autograd will fail
            context = context * maybe_dropout
            flow = flow * maybe_dropout

        if flow_up is not None and self._accumulate_flow:
            flow += flow_up

        # Upsample flow for the next lower level.
        flow_up = uflow_utils.upsample(flow, is_flow=True)
        if self._num_context_up_channels:
            context_up = self._context_up_layers[level](context)

        # Append results to list.
        flows.insert(0, flow)

    # Refine flow at level 1.
    refinement = torch.cat([context, flow], dim=1)
    for layer in self._refine_model:
        refinement = layer(refinement)

    # dropout refinement
    if self.training and self._drop_out_rate:
        maybe_dropout = (torch.rand([]) > self._drop_out_rate).type(
            torch.get_default_dtype())
        # note that operation must not be inplace, otherwise autograd will fail
        refinement = refinement * maybe_dropout

    refined_flow = flow + refinement
    flows[0] = refined_flow

    # Two extra upsampling steps bring the list up to the finest resolutions.
    flows.insert(0, uflow_utils.upsample(flows[0], is_flow=True))
    flows.insert(0, uflow_utils.upsample(flows[0], is_flow=True))

    return flows
def coords_grid(batch, ht, wd):
    """Return a ``[batch, 2, ht, wd]`` float grid of pixel coordinates.

    Channel 0 holds the column (x) index and channel 1 the row (y) index;
    the same grid is repeated for every batch element.
    """
    row_idx, col_idx = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    xy = torch.stack((col_idx, row_idx), dim=0).float()
    return xy.unsqueeze(0).repeat(batch, 1, 1, 1)
def main(args):
    """Compare vanilla HMC, autoguides and NeuTra HMC on the banana density.

    Draws eight panels (true density, vanilla HMC, and for both a DiagNormal
    and a BNAF autoguide: guide samples, warped NeuTra samples, transformed
    NeuTra samples) and saves the figure as ``neutra.pdf`` next to this file.
    """
    pyro.set_rng_seed(args.rng_seed)

    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    # Creation order of the subplots is kept: left column first for rows 1-3.
    cells = [(0, 0), (0, 1), (1, 0), (2, 0), (3, 0), (1, 1), (2, 1), (3, 1)]
    ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8 = [
        fig.add_subplot(gs[row, col]) for row, col in cells
    ]

    xlim = tuple(map(int, args.x_lim.strip().split(',')))
    ylim = tuple(map(int, args.y_lim.strip().split(',')))
    assert len(xlim) == 2
    assert len(ylim) == 2

    # Evaluation grid and true density, shared by all contour panels.
    x1, x2 = torch.meshgrid(
        [torch.linspace(*xlim, 100),
         torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))

    def _posterior_panel(ax, samples, title):
        # True density as background, KDE of the samples on top.
        ax.contourf(x1, x2, p, cmap='OrRd')
        ax.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, title=title)
        sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax)

    def _warped_panel(ax, latents, title):
        # Scatter of the samples in the warped (latent) space.
        sns.scatterplot(latents[:, 0], latents[:, 1], alpha=0.2, ax=ax)
        ax.set(xlabel='x0', ylabel='x1', title=title)

    def _sample_guide(fitted_guide):
        # Draw args.num_samples samples of 'x' from a fitted guide.
        with pyro.plate('N', args.num_samples):
            return fitted_guide()['x'].detach().cpu().numpy()

    def _neutra_hmc(fitted_guide):
        # Run HMC on the model reparameterized through the frozen guide.
        reparam = NeuTraReparam(fitted_guide.requires_grad_(False))
        reparam_model = poutine.reparam(model, config=lambda _: reparam)
        chain = run_hmc(args, reparam_model)
        return reparam, chain.get_samples()['x_shared_latent']

    # 1. Plot samples drawn from BananaShaped distribution
    ax1.contourf(
        x1,
        x2,
        p,
        cmap='OrRd',
    )
    ax1.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    _posterior_panel(ax2, vanilla_samples, 'Posterior \n(vanilla HMC)')

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    guide_samples = _sample_guide(guide)
    _posterior_panel(ax3, guide_samples,
                     'Posterior \n(DiagNormal autoguide)')

    # 3(b). Draw samples using NeuTra HMC
    logging.info(
        '\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    neutra, zs = _neutra_hmc(guide)
    _warped_panel(ax4, zs,
                  'Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')
    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    _posterior_panel(ax5, samples,
                     'Posterior (transformed) \n(DiagNormal + NeuTra HMC)')

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(
        model, partial(iterated, args.num_flows, block_autoregressive))
    fit_guide(guide, args)
    guide_samples = _sample_guide(guide)
    _posterior_panel(ax6, guide_samples, 'Posterior \n(BNAF autoguide)')

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra, zs = _neutra_hmc(guide)
    _warped_panel(ax7, zs,
                  'Posterior (warped) samples \n(BNAF + NeuTra HMC)')
    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    _posterior_panel(ax8, samples,
                     'Posterior (transformed) \n(BNAF + NeuTra HMC)')

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
def construct_image(c):
    """Render 64x64 images from coefficients of a bivariate polynomial basis.

    The basis consists of the 15 monomials ``x**i * y**j`` with
    ``i + j <= 4`` (ordered by ``i``, then ``j``), evaluated on a grid where
    ``x`` runs from -5 to 5 along the first axis and ``y`` from 5 to -5
    along the second axis.

    Args:
        c: coefficient matrix whose second dimension indexes the 15 monomials.

    Returns:
        Tensor of shape ``(-1, 64, 64)``, one image per row of ``c``.
    """
    axis = torch.linspace(-5, 5, 64)
    x, y = torch.meshgrid(axis, -axis)

    monomials = []
    for i in range(5):
        for j in range(5 - i):
            monomials.append(x ** i * y ** j)
    basis = torch.stack(monomials)

    images = torch.einsum('ij,jkl->ikl', c, basis)
    return images.view(-1, 64, 64)
def get_grid(size):
    """Return the integer index grid of an N-d array as a float tensor.

    Args:
        size: iterable with the extent of every dimension.

    Returns:
        Tensor of shape ``(len(size), *size)``; entry ``k`` holds the index
        along dimension ``k`` at every grid position.
    """
    axes = list(map(torch.arange, size))
    index_maps = torch.meshgrid(axes)
    return torch.stack(index_maps, 0).float()
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" device = [0, 1] from model import CreateModel import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as Data from tqdm import tqdm torch.manual_seed(0) # Data tick = torch.arange(0., 256.) / 255. #0~255 R, G, B = torch.meshgrid(tick, tick, tick) # gamma 1.1 X = torch.stack( (torch.reshape(R, (-1, 1)), torch.reshape(G, (-1, 1)), torch.reshape(B, (-1, 1))), dim=1) X = torch.unsqueeze(X, 2) Y = torch.pow(X, 1 / 1.1) X -= 0.5 Y -= 0.5 dataset = Data.TensorDataset(X, Y) data_iter = Data.DataLoader(dataset,
def from_matrix(ranges_i, ranges_j, keep):
    r"""Turns a boolean matrix into a KeOps-friendly **ranges** argument.

    This routine is a helper for the **block-sparse** reduction mode of KeOps,
    allowing you to turn clustering information (**ranges_i**,
    **ranges_j**) and a cluster-to-cluster boolean mask (**keep**)
    into integer tensors of indices that can be used to schedule the KeOps routines.

    Suppose that you're working with variables :math:`x_i`  (:math:`i \in [0,10^6)`),
    :math:`y_j`  (:math:`j \in [0,10^7)`), and that you want to compute a KeOps reduction
    over indices :math:`i` or :math:`j`: Instead of performing the full
    kernel dot product (:math:`10^6 \cdot 10^7 = 10^{13}` operations!),
    you may want to restrict yourself to
    interactions between points :math:`x_i` and :math:`y_j` that are "close" to each other.

    With KeOps, the simplest way of doing so is to:

    1. Compute cluster labels for the :math:`x_i`'s and :math:`y_j`'s, using e.g.
       the :func:`grid_cluster` method.
    2. Compute the ranges (**ranges_i**, **ranges_j**) and centroids associated
       to each cluster, using e.g. the :func:`cluster_ranges_centroids` method.
    3. Sort the tensors ``x_i`` and ``y_j`` with :func:`sort_clusters` to make sure that the
       clusters are stored contiguously in memory (this step is **critical** for performance on GPUs).

    At this point:
        - the :math:`k`-th cluster of :math:`x_i`'s is given by ``x_i[ ranges_i[k,0]:ranges_i[k,1], : ]``, for :math:`k \in [0,M)`,
        - the :math:`\ell`-th cluster of :math:`y_j`'s is given by ``y_j[ ranges_j[l,0]:ranges_j[l,1], : ]``, for :math:`\ell \in [0,N)`.

    4. Compute the :math:`(M,N)` matrix **dist** of pairwise distances between cluster centroids.
    5. Apply a threshold on **dist** to generate a boolean matrix ``keep = dist < threshold``.
    6. Define a KeOps reduction ``my_genred = Genred(..., axis = 0 or 1)``, as usual.
    7. Compute the block-sparse reduction through
       ``result = my_genred(x_i, y_j, ranges = from_matrix(ranges_i,ranges_j,keep) )``

    :func:`from_matrix` is thus the routine that turns a **high-level description**
    of your block-sparse computation (cluster ranges + boolean matrix)
    into a set of **integer tensors** (the **ranges** optional argument),
    used by KeOps to schedule computations on the GPU.

    Args:
        ranges_i ((M,2) IntTensor): List of :math:`[\text{start}_k,\text{end}_k)` indices.
            For :math:`k \in [0,M)`, the :math:`k`-th cluster of ":math:`i`" variables is
            given by ``x_i[ ranges_i[k,0]:ranges_i[k,1], : ]``, etc.
        ranges_j ((N,2) IntTensor): List of :math:`[\text{start}_\ell,\text{end}_\ell)` indices.
            For :math:`\ell \in [0,N)`, the :math:`\ell`-th cluster of ":math:`j`" variables is
            given by ``y_j[ ranges_j[l,0]:ranges_j[l,1], : ]``, etc.
        keep ((M,N) BoolTensor):
            If the output ``ranges`` of :func:`from_matrix` is used in a KeOps reduction,
            we will only compute and reduce the terms associated to pairs of "points"
            :math:`x_i`, :math:`y_j` in clusters :math:`k` and :math:`\ell`
            if ``keep[k,l] == 1``.

    Returns:
        A 6-tuple of integer tensors
        ``(ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i)``
        that can be used as an optional **ranges** argument of
        :class:`torch.Genred <pykeops.torch.Genred>`.  ``ranges_i`` and
        ``ranges_j`` are returned unchanged, the ``redranges`` are rows gathered
        from them and the ``slices`` are int32 cumulative counts.
        See the documentation of :class:`torch.Genred <pykeops.torch.Genred>` for reference.

    Example:
        >>> r_i = torch.IntTensor( [ [2,5], [7,12] ] )          # 2 clusters: X[0] = x_i[2:5], X[1] = x_i[7:12]
        >>> r_j = torch.IntTensor( [ [1,4], [4,9], [20,30] ] )  # 3 clusters: Y[0] = y_j[1:4], Y[1] = y_j[4:9], Y[2] = y_j[20:30]
        >>> x,y = torch.Tensor([1., 0.]), torch.Tensor([1.5, .5, 2.5])  # dummy "centroids"
        >>> dist = (x[:,None] - y[None,:])**2
        >>> keep = (dist <= 1)  # (2,3) matrix
        >>> print(keep)
        tensor([[1, 1, 0],
                [0, 1, 0]], dtype=torch.uint8)
        --> X[0] interacts with Y[0] and Y[1], X[1] interacts with Y[1]
        >>> (ranges_i,slices_i,redranges_j, ranges_j,slices_j,redranges_i) = from_matrix(r_i,r_j,keep)
        --> (ranges_i,slices_i,redranges_j) will be used for reductions with respect to "j" (axis=1)
        --> (ranges_j,slices_j,redranges_i) will be used for reductions with respect to "i" (axis=0)

        Information relevant if **axis** = 1:

        >>> print(ranges_i)  # = r_i
        tensor([[ 2,  5],
                [ 7, 12]], dtype=torch.int32)
        --> Two "target" clusters in a reduction wrt. j
        >>> print(slices_i)
        tensor([2, 3], dtype=torch.int32)
        --> X[0] is associated to redranges_j[0:2]
        --> X[1] is associated to redranges_j[2:3]
        >>> print(redranges_j)
        tensor([[1, 4],
                [4, 9],
                [4, 9]], dtype=torch.int32)
        --> For X[0], i in [2,3,4],       we'll reduce over j in [1,2,3] and [4,5,6,7,8]
        --> For X[1], i in [7,8,9,10,11], we'll reduce over j in [4,5,6,7,8]

        Information relevant if **axis** = 0:

        >>> print(ranges_j)
        tensor([[ 1,  4],
                [ 4,  9],
                [20, 30]], dtype=torch.int32)
        --> Three "target" clusters in a reduction wrt. i
        >>> print(slices_j)
        tensor([1, 3, 3], dtype=torch.int32)
        --> Y[0] is associated to redranges_i[0:1]
        --> Y[1] is associated to redranges_i[1:3]
        --> Y[2] is associated to redranges_i[3:3] = no one...
        >>> print(redranges_i)
        tensor([[ 2,  5],
                [ 2,  5],
                [ 7, 12]], dtype=torch.int32)
        --> For Y[0], j in [1,2,3],     we'll reduce over i in [2,3,4]
        --> For Y[1], j in [4,5,6,7,8], we'll reduce over i in [2,3,4] and [7,8,9,10,11]
        --> For Y[2], j in [20,21,...,29], there is no reduction to be done
    """
    # nonzero() enumerates the kept (row, col) pairs in row-major order and
    # returns them on keep's own device, so no dense (M,N) index grid has to
    # be allocated and the routine also works when keep lives on the GPU.
    #
    # Row-major scan of keep: for every x-cluster k, the kept y-clusters l.
    kept_j = keep.nonzero()[:, 1]
    # Row-major scan of keep.t(): for every y-cluster l, the kept x-clusters k.
    kept_i = keep.t().nonzero()[:, 1]

    redranges_i = ranges_i[kept_i]  # "stacked" copies of ranges_i[...]
    redranges_j = ranges_j[kept_j]  # "stacked" copies of ranges_j[...]

    slices_i = keep.sum(1).cumsum(0).int()  # slice indices in the "stacked" array redranges_j
    slices_j = keep.sum(0).cumsum(0).int()  # slice indices in the "stacked" array redranges_i
    return (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i)
def make_material_atlas(
    image: torch.Tensor,
    faces_verts_uvs: torch.Tensor,
    texture_size: int,
) -> torch.Tensor:
    r"""
    Given a single texture image and the uv coordinates for all the
    face vertices, create a square texture map per face using
    the formulation from [1].

    For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate system
    with the x axis being the vector (v1 - v0) and the y axis being the vector (v2 - v0).
    The barycentric coordinates range from [0, 1] in the +x and +y direction so this creates
    a triangular texture space with vertices at (0, 1), (0, 0) and (1, 0).

    The per face texture map is of shape (texture_size, texture_size, 3)
    which is a square. To map a triangular texture to a square grid, each
    triangle is parametrized as follows (e.g. R = texture_size = 3):

    The triangle texture is first divided into RxR = 9 subtriangles which each
    map to one grid cell. The numbers in the grid cells and triangles show the mapping.

    ..code-block::python

        Triangular Texture Space:

              1
                |\
                |6 \
                |____\
                |\  7 |\
                |3 \  |4 \
                |____\|____\
                |\ 8  |\  5 |\
                |0 \  |1 \  |2 \
                |____\|____\|____\
               0                   1

        Square per face texture map:

               R ____________________
                |      |      |      |
                |  6   |  7   |  8   |
                |______|______|______|
                |      |      |      |
                |  3   |  4   |  5   |
                |______|______|______|
                |      |      |      |
                |  0   |  1   |  2   |
                |______|______|______|
               0                      R


    The barycentric coordinates of each grid cell are calculated using the
    xy coordinates:

    ..code-block::python

            The cartesian coordinates are:

            Grid 1:

               R ____________________
                |      |      |      |
                |  20  |  21  |  22  |
                |______|______|______|
                |      |      |      |
                |  10  |  11  |  12  |
                |______|______|______|
                |      |      |      |
                |  00  |  01  |  02  |
                |______|______|______|
               0                      R

            where 02 means y = 0, x = 2

    Now consider this subset of the triangle which corresponds to
    grid cells 0 and 8:

    ..code-block::python

        1/R  ________
            |\    8  |
            |  \     |
            | 0   \  |
            |_______\|
           0          1/R

    The centroids of the triangles are:
        0: (1/3, 1/3) * 1/R
        8: (2/3, 2/3) * 1/R

    For each grid cell we can now calculate the centroid `(c_y, c_x)`
    of the corresponding texture triangle:
        - if `(x + y) < R`, then offset the centroid of
            triangle 0 by `(y, x) * (1/R)`
        - if `(x + y) >= R`, then offset the centroid of
            triangle 8 by `((R-1-y), (R-1-x)) * (1/R)`.

    This is equivalent to updating the portion of Grid 1
    above the diagonal, replacing `(y, x)` with `((R-1-y), (R-1-x))`:

    ..code-block::python

               R ____________________
                |      |      |      |
                |  20  |  01  |  00  |
                |______|______|______|
                |      |      |      |
                |  10  |  11  |  10  |
                |______|______|______|
                |      |      |      |
                |  00  |  01  |  02  |
                |______|______|______|
               0                      R

    The barycentric coordinates (w0, w1, w2) are then given by:

    ..code-block::python

        w0 = c_x
        w1 = c_y
        w2 = 1 - w0 - w1

    Args:
        image: FloatTensor of shape (H, W, 3)
        faces_verts_uvs: uv coordinates for each vertex in each face  (F, 3, 2)
        texture_size: int

    Returns:
        atlas: a FloatTensor of shape (F, texture_size, texture_size, 3) giving a
            per face texture map.

    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    R = texture_size
    device = faces_verts_uvs.device
    rng = torch.arange(R, device=device)

    # Meshgrid returns (row, column) i.e (Y, X)
    # Change order to (X, Y) to make the grid.
    Y, X = torch.meshgrid(rng, rng)
    grid = torch.stack([X, Y], dim=-1)  # (R, R, 2)

    # Grid cells below the diagonal: x + y < R.
    below_diag = grid.sum(-1) < R

    # map a [0, R] grid -> to a [0, 1] barycentric coordinates of
    # the texture triangle centroids.
    bary = torch.zeros((R, R, 3), device=device)  # (R, R, 3)
    # Column index (2, 1): broadcasts against the K selected cells so that a
    # (2, K) right-hand side fills channels 0 and 1 of those cells.
    slc = torch.arange(2, device=device)[:, None]

    # w0, w1 for the cells below the diagonal.
    bary[below_diag, slc] = ((grid[below_diag] + 1.0 / 3.0) / R).t()

    # w0, w1 for above diagonal grid cells (x + y >= R).
    bary[~below_diag, slc] = (
        ((R - 1.0 - grid[~below_diag]) + 2.0 / 3.0) / R
    ).t()

    # w2 = 1. - w0 - w1
    bary[..., -1] = 1 - bary[..., :2].sum(dim=-1)

    # Calculate the uv position in the image for each pixel
    # in the per face texture map
    # (F, 1, 1, 3, 2) * (R, R, 3, 1) -> (F, R, R, 3, 2) -> (F, R, R, 2)
    uv_pos = (faces_verts_uvs[:, None, None] * bary[..., None]).sum(-2)

    # bi-linearly interpolate the textures from the images
    # using the uv coordinates given by uv_pos.
    textures = _bilinear_interpolation_vectorized(image, uv_pos)

    return textures
def forward(self, x):
    """Apply the wrapped ``self.model`` to ``x`` and return its output."""
    return self.model(x)


model = NETWORK()
# model.load_state_dict(torch.load("E:/python/Jupyter/TrainedModels_Saving/sinhtcosx_for_spare_use.pth"))
model.load_state_dict(torch.load("./nn_proportional_delay_PDE.pth"))

# plot 3-D figure
N_t, N_x = 500, 100
t = torch.linspace(0, 1.0, N_t)
x = torch.linspace(0, 2 * pi, N_x)
t, x = torch.meshgrid(t, x)

# Flatten the (N_t, N_x) grids row-major into an (N_t * N_x, 2) batch of
# (t, x) points: row i * N_x + j holds (t[i, j], x[i, j]).  A single stack
# replaces a Python loop over all 50,000 grid points.
T_X = torch.stack((t.reshape(-1), x.reshape(-1)), dim=1)

u = model(T_X).detach().numpy().reshape(N_t, N_x)
t = t.numpy()
x = x.numpy()
# Reference surface sinh(t) * cos(x) evaluated on the same grid.
u_acc = np.sinh(t) * np.cos(x)

fig = plt.figure()
fig.set_size_inches(10, 4)
ax0 = fig.add_subplot(1, 2, 1, projection='3d')
ax0.plot_surface(t, x, u, rstride=1, cstride=1, cmap='rainbow')
# ax0.grid(b = False)
def forward(self, box_cls_all, box_reg_all, centerness_all, boxes_all):
    """Decode per-box grid predictions into refined detections per image.

    For every input box and every class, the location with the highest
    centerness-weighted class score is selected, the four regression
    values at that location are turned into an ``xyxy`` box on the
    ``self.m`` x ``self.m`` grid, and that box is rescaled into the
    coordinates of the input box.

    Args:
        box_cls_all: 4-D classification logits, concatenated over all
            images along dim 0 (one entry per input box).
        box_reg_all: 4-D regression outputs, same batching; reshaped to
            four values per location.
        centerness_all: 4-D centerness logits, same batching.
        boxes_all: per-image box containers; each must support ``len()``,
            ``.bbox``, ``.size`` and ``get_field("scores")``.

    Returns:
        A list with one concatenated box list per image: the input boxes
        followed by ``self.num_classes`` decoded boxes for each input box,
        carrying ``"labels"`` and ``"scores"`` fields.
    """
    device = box_cls_all.device
    boxes_per_image = [len(box) for box in boxes_all]
    cls = box_cls_all.split(boxes_per_image, dim=0)
    reg = box_reg_all.split(boxes_per_image, dim=0)
    center = centerness_all.split(boxes_per_image, dim=0)

    # Centres of the cells of the m x m grid, flattened row-major as (x, y).
    # They depend only on self.m, so they are built once rather than once
    # per box.
    x_step = 1.0
    y_step = 1.0
    shifts_x = torch.arange(
        0, self.m, step=x_step, dtype=torch.float32, device=device
    ) + x_step / 2
    shifts_y = torch.arange(
        0, self.m, step=y_step, dtype=torch.float32, device=device
    ) + y_step / 2
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    locations = torch.stack(
        (shift_x.reshape(-1), shift_y.reshape(-1)), dim=1
    )

    results = []
    for box_cls, box_regression, centerness, boxes in zip(
        cls, reg, center, boxes_all
    ):
        N = box_cls.shape[0]

        # put in the same format as locations
        box_cls = (
            box_cls.permute(0, 2, 3, 1)
            .reshape(N, -1, self.num_classes)
            .sigmoid()
        )
        box_regression = box_regression.permute(0, 2, 3, 1).reshape(N, -1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(N, -1).sigmoid()

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        _boxes = boxes.bbox
        size = boxes.size
        boxes_scores = boxes.get_field("scores")

        results_per_image = [boxes]
        for i in range(N):
            box = _boxes[i]
            boxes_score = boxes_scores[i]

            # Best location (and its score) for every class.
            per_box_cls_max, per_box_cls_inds = box_cls[i].max(dim=0)

            # Labels 2 .. num_classes + 1, the same values the deprecated
            # end-inclusive torch.range(2, 1 + num_classes) produced.  A
            # fresh tensor per box keeps the box lists independent.
            per_class = torch.arange(
                2, 2 + self.num_classes, dtype=torch.long, device=device
            )

            per_box_regression = box_regression[i][per_box_cls_inds]
            # NOTE(review): indexing the m*m grid with an argmax taken over
            # H*W locations assumes H * W == self.m ** 2 -- confirm.
            per_locations = locations[per_box_cls_inds]

            # Regression values are offsets from the cell centre to the
            # left / top / right / bottom edges, in grid units.
            _x1 = per_locations[:, 0] - per_box_regression[:, 0]
            _y1 = per_locations[:, 1] - per_box_regression[:, 1]
            _x2 = per_locations[:, 0] + per_box_regression[:, 2]
            _y2 = per_locations[:, 1] + per_box_regression[:, 3]

            # Map grid units back into the coordinates of the input box.
            _x1 = _x1 / self.m * (box[2] - box[0]) + box[0]
            _y1 = _y1 / self.m * (box[3] - box[1]) + box[1]
            _x2 = _x2 / self.m * (box[2] - box[0]) + box[0]
            _y2 = _y2 / self.m * (box[3] - box[1]) + box[1]

            detections = torch.stack([_x1, _y1, _x2, _y2], dim=-1)

            boxlist = BoxList(detections, size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field(
                "scores",
                torch.sqrt(torch.sqrt(per_box_cls_max) * boxes_score),
            )
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, 0)
            results_per_image.append(boxlist)

        results_per_image = cat_boxlist(results_per_image)
        results.append(results_per_image)

    return results
def optimize(ux0, uy0, im1, im2, maxIter, lambda_r, to1, theta, device):
    """Iteratively estimate a 2-D displacement field between two images.

    Each iteration solves a per-pixel 2x2 linear system for ``(ux, uy)``,
    filters the result in the DCT domain to obtain ``(vx, vy)`` and updates
    the auxiliary variables ``(bx, by)``.  Iteration stops when the relative
    L1 change of both ``ux`` and ``uy`` drops below ``to1`` or after
    ``maxIter`` iterations.

    Args:
        ux0, uy0: displacement components used to linearise the data term
            (presumably the current flow estimate -- confirm with caller).
        im1, im2: 2-D image tensors of identical size ``(h, w)``.
        maxIter: maximum number of iterations.
        lambda_r: weight of the ``G * G`` term in the DCT-domain filter.
        to1: convergence tolerance on the relative change of ``ux`` / ``uy``.
        theta: coupling weight between ``(ux, uy)`` and ``(vx, vy)``.
        device: device on which all work tensors are allocated.

    Returns:
        Tuple ``(ux, uy)`` of double tensors of shape ``(h, w)``.
    """
    eps = 0.0000001

    Ix, Iy = centralFiniteDifference(im1)
    It = im1 - im2

    a11 = Ix * Ix
    a12 = Ix * Iy
    a22 = Iy * Iy
    residual = It - Ix * ux0 - Iy * uy0
    t1 = Ix * residual
    t2 = Iy * residual

    h, w = im1.size()

    def _zeros():
        return torch.zeros((h, w), dtype=torch.float64, device=device)

    vx, vy = _zeros(), _zeros()
    bx, by = _zeros(), _zeros()
    ux, uy = _zeros(), _zeros()

    # Frequency-domain weights.  They are placed on the requested `device`
    # instead of an unconditional .cuda(), so the routine also runs on the
    # CPU or on a non-default GPU.
    X, Y = torch.meshgrid(torch.linspace(0, h - 1, h),
                          torch.linspace(0, w - 1, w))
    X, Y = X.to(device), Y.to(device)
    G = torch.cos(math.pi * X / h) + torch.cos(math.pi * Y / w) - 2

    # Loop invariants: determinant of the per-pixel 2x2 system and the
    # denominator of the DCT-domain filter.
    det = (a11 + theta) * (a22 + theta) - a12 * a12
    dct_denom = theta + lambda_r * G * G

    for i in range(maxIter):
        tempx = ux
        tempy = uy

        h1 = theta * (vx - bx) - t1
        h2 = theta * (vy - by) - t2

        ux = ((a22 + theta) * h1 - a12 * h2) / det
        uy = ((a11 + theta) * h2 - a12 * h1) / det

        vx = tdct.idct_2d(tdct.dct_2d(theta * (ux + bx)) / dct_denom)
        vy = tdct.idct_2d(tdct.dct_2d(theta * (uy + by)) / dct_denom)

        bx = bx + ux - vx
        by = by + uy - vy

        # Relative L1 change of the flow since the previous iteration.
        stopx = torch.sum(
            torch.abs(ux - tempx)) / (torch.sum(torch.abs(tempx)) + eps)
        stopy = torch.sum(
            torch.abs(uy - tempy)) / (torch.sum(torch.abs(tempy)) + eps)

        if stopx < to1 and stopy < to1:
            print('iterate {} times, stop due to converge to tolerance'.format(
                i))
            break

        if i == maxIter - 1:
            print('iterate {} times, stop due to reach max iteration'.format(i))

    return ux, uy