def raster_soft(self, edt2, conn):
    """Turn squared distance transforms into one gated soft raster.

    ``edt2`` is passed through ``exp`` with ``self.sigma2``; the result is
    scaled by ``conn`` (reshaped to ``(edt2.shape[0], -1, 1, 1)`` so it
    broadcasts over the two trailing dimensions) and the scaled rasters are
    merged with ``softor(..., keepdim=True)``.

    When ``self.hard`` is truthy the gate is thresholded at 0.5 for the
    forward value, while the correction term is detached so gradients still
    flow through the un-thresholded gate.
    """
    soft = exp(edt2, self.sigma2)
    gate = conn.view(edt2.shape[0], -1, 1, 1)

    if self.hard:
        # Forward value becomes the 0/1 threshold (up to float rounding);
        # the detached difference contributes no gradient of its own.
        thresholded = (gate > 0.5).float()
        gate = gate + (thresholded - gate).detach()

    return softor(soft * gate, keepdim=True)
def render_crs(params, sigma2, grid, coordpairs):
    """Rasterise curves described by control points in ``params``.

    ``coordpairs`` supplies, in its four columns, the indices (along dim 1
    of ``params``) of the four control points used by each curve segment.
    The gathered points are reshaped to ``(params.shape[0], -1, 4, 2)``,
    passed to ``curve_edt2_polyline`` with
    ``centripetal_catmull_rom_spline`` as the curve function, converted
    with ``exp(..., sigma2)``, merged with ``softor`` over dim 1 and given
    a new leading dimension.
    """
    n_curves = params.shape[0]

    # One gather per control-point column, joined on the last dimension
    # ([batch, nlines, 8] per the original annotation).
    gathered = torch.cat(
        [params[:, coordpairs[:, col]] for col in range(4)], dim=-1
    )
    segments = gathered.view(n_curves, -1, 4, 2)

    # The literal 10 is forwarded positionally to curve_edt2_polyline;
    # presumably a sampling resolution -- TODO confirm against that helper.
    edt2 = curve_edt2_polyline(
        segments, grid, 10, cfcn=centripetal_catmull_rom_spline
    )
    return softor(exp(edt2, sigma2), dim=1).unsqueeze(0)
def render(params, cparams, sigma2, grid, coordpairs, args):
    """Render points, lines and curves from a flat parameter vector.

    Layout of ``params`` (as sliced below):
      * the first ``2 * args.points`` entries are viewed as
        ``(args.points, 2)`` and rendered with ``render_points``;
      * the next ``4 * args.lines`` entries are viewed as
        ``(args.lines, 2, 2)`` and rendered with ``render_lines``;
      * everything after that is viewed as
        ``(args.crs, 2 + args.crs_points, 2)`` and rendered with
        ``render_crs``.

    ``sigma2`` is used as-is for every primitive kind unless it is a
    ``torch.Tensor``, in which case it is sliced per kind in the order
    points, lines, curves.

    If ``cparams`` is ``None`` the rasters are merged with
    ``softor(dim=1, keepdim=True)``.  Otherwise a size-3 dimension is
    inserted at position 2, the rasters are multiplied by ``cparams``
    (extended with two trailing singleton dimensions) and composited with
    ``over(dim=1, keepdim=False)``.
    """
    n_points, n_lines, n_curves = args.points, args.lines, args.crs
    sigma_is_tensor = isinstance(sigma2, torch.Tensor)

    # Offsets into the flat parameter vector.
    lines_start = 2 * n_points
    curves_start = lines_start + 4 * n_lines

    layers = []

    if n_points > 0:
        point_params = params[0:lines_start].view(n_points, 2)
        point_sigma = sigma2[0:n_points] if sigma_is_tensor else sigma2
        layers.append(render_points(point_params, point_sigma, grid))

    if n_lines > 0:
        line_params = params[lines_start:curves_start].view(n_lines, 2, 2)
        if sigma_is_tensor:
            line_sigma = sigma2[n_points:n_points + n_lines]
        else:
            line_sigma = sigma2
        layers.append(render_lines(line_params, line_sigma, grid))

    if n_curves > 0:
        curve_params = params[curves_start:].view(
            n_curves, 2 + args.crs_points, 2)
        if sigma_is_tensor:
            curve_sigma = sigma2[n_points + n_lines:]
        else:
            curve_sigma = sigma2
        layers.append(render_crs(curve_params, curve_sigma, grid, coordpairs))

    # NOTE(review): the original annotates this result as
    # [1, nprim, row, col] although the concatenation is along dim 0 --
    # confirm the renderers' output shapes when primitive kinds are mixed.
    stacked = torch.cat(layers, dim=0)

    if cparams is None:
        return softor(stacked, dim=1, keepdim=True)

    # Insert a size-3 dimension at position 2 so each primitive raster can
    # be scaled by its colour parameters before compositing.
    tinted = stacked.unsqueeze(2).repeat_interleave(3, dim=2)
    colours = cparams.unsqueeze(-1).unsqueeze(-1)
    return over(colours * tinted, dim=1, keepdim=False)
def forward(self, inp, state=None):
    """Build a composite raster over ``self.steps`` refinement steps.

    Starting from an all-zero canvas of shape
    ``(inp.shape[0], 1, self.sz, self.sz)`` on ``inp``'s device, each step:

      1. encodes the current canvas with ``self.encoder``;
      2. concatenates ``inp`` with that encoding on the last dimension;
      3. decodes the result into ``params`` and ``conn`` via
         ``self.decode_to_params``;
      4. computes ``edt2 = self.create_edt2(params)`` and its soft raster
         ``self.raster_soft(edt2, conn)``;
      5. merges the raster into the canvas with ``softor`` over the
         concatenation on dim 1, restoring the channel dimension.

    Args:
        inp: input tensor; only ``shape[0]`` and ``device`` are read here
            besides the concatenation with the encoder output.
        state: optional mapping.  When given, it receives the hard raster
            and the squared distance transform of the *last* step under
            ``metrics.HARDRASTER`` and ``metrics.SQ_DISTANCE_TRANSFORM``.

    Returns:
        The final composite canvas.
    """
    batch_size = inp.shape[0]
    composite = torch.zeros(
        batch_size, 1, self.sz, self.sz, device=inp.device)

    # Only the most recent step's distance transform is needed afterwards;
    # the per-step parameters are not accumulated because nothing reads
    # them (the previous list only kept those tensors alive).
    edt2 = None
    for _ in range(self.steps):
        prev_latent = self.encoder(composite)
        latent = torch.cat((inp, prev_latent), dim=-1)
        params, conn = self.decode_to_params(latent)

        edt2 = self.create_edt2(params)
        ras = self.raster_soft(edt2, conn)
        composite = softor(torch.cat((composite, ras), dim=1)).unsqueeze(1)

    if state is not None:
        state[metrics.HARDRASTER] = self.raster_hard(edt2)
        state[metrics.SQ_DISTANCE_TRANSFORM] = edt2

    return composite
def raster_soft(self, edt2, conn):
    """Turn squared distance transforms into one weighted soft raster.

    ``edt2`` is converted with ``exp(edt2, self.sigma2)``, scaled by
    ``conn`` reshaped to ``(edt2.shape[0], -1, 1, 1)`` so that it
    broadcasts over the two trailing dimensions, and merged with
    ``softor(..., keepdim=True)``.
    """
    soft = exp(edt2, self.sigma2)
    weights = conn.view(edt2.shape[0], -1, 1, 1)
    return softor(soft * weights, keepdim=True)
def raster_hard(self, edt2):
    """Produce a hard raster from squared distance transforms.

    ``edt2`` is detached first, so no gradient flows back through this
    path.  The detached tensor is rasterised with ``nearest_neighbour``
    using the value returned by
    ``compute_nearest_neighbour_sigma_bres(self.grid)``, and the rasters
    are merged with ``softor(..., keepdim=True)``.
    """
    detached = edt2.detach()
    nn_sigma = compute_nearest_neighbour_sigma_bres(self.grid)
    hard = nearest_neighbour(detached, nn_sigma)
    return softor(hard, keepdim=True)
def raster_soft(self, edt2, sigma2):
    """Turn squared distance transforms into one soft raster.

    Applies ``exp(edt2, sigma2)`` and merges the resulting rasters with
    ``softor(..., keepdim=True)``.  Unlike variants that read the sigma
    from the instance, this one takes ``sigma2`` from the caller.
    """
    return softor(exp(edt2, sigma2), keepdim=True)