def plot(self, images):
    perrow = 5

    num, c, w, h = images.size()
    rows = int(math.ceil(num/perrow))

    prep = self.preprocess(images)

    b, g, _ = prep.size()
    ci, hi, wi = self.in_shape
    hg, wg = self.glimpse_size

    # glimpse centres (rounded to integer pixel coordinates) and their covariances
    means = prep[:, :, :2]
    means = sparse.transform_means(means, (hi-hg, wi-wg)).round().long()
    sigmas = prep[:, :, 2]
    sigmas = sparse.transform_sigmas(sigmas, self.in_shape[1:])

    images = images.data

    plt.figure(figsize=(perrow * 3, rows * 3))

    for i in range(num):
        ax = plt.subplot(rows, perrow, i+1)

        im = np.transpose(images[i, :, :, :].cpu().numpy(), (1, 2, 0))
        im = np.squeeze(im)

        ax.imshow(im, interpolation='nearest', extent=(-0.5, w-0.5, -0.5, h-0.5), cmap='gray_r')

        util.plot(means[i, :, :].unsqueeze(0), sigmas[i, :, :].unsqueeze(0),
                  torch.ones(means[:, :, 0].size()),
                  axes=ax, flip_y=h, alpha_global=0.8/self.num_glimpses)

    # return the current figure so callers can save or show it
    return plt.gcf()
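# Hedged usage sketch (the caller names below are assumptions, not from this code):
# once plot() has drawn the glimpses, the returned figure can be written to disk
# with standard matplotlib calls.
#
#   fig = model.plot(images)            # the plot() method above
#   fig.savefig('glimpses.png', dpi=150)
#   plt.close(fig)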
def forward(self, image):

    prep = self.preprocess(image)

    b, g, _ = prep.size()
    ci, hi, wi = self.in_shape
    hg, wg = self.glimpse_size

    # first two channels are the glimpse centres, the third their (scalar) sigma
    means = prep[:, :, :2]
    sigmas = prep[:, :, 2]
    sigmas = sparse.transform_sigmas(sigmas, self.in_shape[1:])

    # sample glimpse centres from a Normal around the predicted means
    stoch_means = torch.distributions.Normal(means, sigmas)
    sample = stoch_means.sample()

    point_means = sparse.transform_means(sample, (hi-hg, wi-wg)).round().long()

    # extract one (c, hg, wg) glimpse per sampled centre; result is (b, g, c, hg, wg)
    batch = []
    for bi in range(b):
        extracts = []
        for gi in range(g):
            h, w = point_means[bi, gi, :]
            ext = image.data[bi, :, h:h+hg, w:w+wg]
            extracts.append(ext[None, None, :, :, :])
        batch.append(torch.cat(extracts, dim=1))
    result = torch.cat(batch, dim=0)

    return result, stoch_means, sample
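# Hedged sketch, not from this code: because the sampled glimpse centres are rounded
# to integer pixels, gradients cannot flow through the extraction itself, so the
# returned Normal distribution and sample are the natural ingredients of a
# score-function (REINFORCE-style) estimator. The reward below is a placeholder.
import torch

loc = torch.zeros(2, 3, 2, requires_grad=True)       # (b, g, 2) predicted centres
dist = torch.distributions.Normal(loc, torch.ones(2, 3, 2))
sample = dist.sample()                                # non-differentiable sample
logprob = dist.log_prob(sample).sum(dim=(1, 2))       # (b,) log-density per instance
reward = torch.randn(2)                               # placeholder per-instance reward
loss = -(reward * logprob).mean()
loss.backward()                                       # gradients reach the centre predictions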
def hyper(self, input, prep=None):
    """
    Evaluates hypernetwork.
    """
    b, c, h, w = input.size()

    quad = self.preprocess(input)  # Compute the attention quadrangle

    b, g, _, _ = quad.size()    # (b, g, 4, 2)
    k, k, _ = self.grid.size()  # (k, k, 4)

    # ensure that the bounding box covers a reasonable area of the image at the start
    quad = quad + self.quad_offset[None, None, :]

    # Fit to the max pixel values
    quad = sparse.transform_means(quad.view(b, g*4, 2), (h, w)).view(b, g, 4, 2)

    # Interpolate between the four corners of the quad
    grid = self.grid[None, None, :, :, :, None]  # b, g, k, k, 4, 2
    quad = quad[:, :, None, None, :, :]

    res = (grid * quad).sum(dim=4)
    assert res.size() == (b, g, k, k, 2)

    means = res.view(b, g * k * k, 2)

    # Expand sigmas
    sigmas = self.sigmas[None, :, None].expand(b, g, (k*k)).contiguous().view(b, (g*k*k))
    sigmas = sparse.transform_sigmas(sigmas, (h, w))
    sigmas = sigmas * self.sigma_scale + self.min_sigma

    values = self.one[None, :].expand(b, k*k*g)

    return means, sigmas, values
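# Hedged illustration of the interpolation step above: each of the k*k grid points
# is a weighted combination of the four quad corners (the weights below are example
# bilinear weights, not the contents of self.grid).
import torch

corners = torch.tensor([[0., 0.], [0., 4.], [4., 0.], [4., 4.]])  # (4, 2) example quad
weights = torch.tensor([0.25, 0.25, 0.25, 0.25])                  # one grid point's corner weights
print((weights[:, None] * corners).sum(dim=0))                    # tensor([2., 2.]), the quad centre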
def hyper(self, x):

    assert x.size()[1:] == self.in_size

    b, c, h, w = x.size()
    k = self.k

    # the coordinates of the current pixels in parameter space
    # - the index tuples are described relative to these
    hw = torch.tensor((h, w), device=d(x), dtype=torch.float)
    mids = self.coords[None, :, :, :].expand(b, 2, h, w) * (hw - 1)[None, :, None, None]
    mids = mids.permute(0, 2, 3, 1)

    if not self.modulo:
        mids = util.inv(mids, mx=hw[None, None, None, :])

    mids = mids[:, :, :, None, :].expand(b, h, w, k, 2)

    # add coords to channels
    if self.admode == 'none':
        params = self.params[None, None, None, :].expand(b, h, w, k * 3)
    else:
        if self.admode == 'full':
            coords = self.coords[None, :, :, :].expand(b, 2, h, w)
            x = torch.cat([x, coords], dim=1)
        elif self.admode == 'coords':
            x = self.coords[None, :, :, :].expand(b, 2, h, w)
        elif self.admode == 'inputs':
            pass
        else:
            raise Exception(f'adaptivity mode {self.admode} not recognized')

        x = x.permute(0, 2, 3, 1)
        params = self.toparams(x)

    assert params.size() == (b, h, w, k * 3)  # k index tuples per output pixel

    means  = params[:, :, :, :k * 2].view(b, h, w, k, 2)
    sigmas = params[:, :, :, k * 2:].view(b, h, w, k)
    values = self.mvalues[None, None, None, :].expand(b, h, w, k)

    means = mids + self.mmult * means

    s = (h, w)
    means  = sparse.transform_means(means, s, method='modulo' if self.modulo else 'sigmoid')
    sigmas = sparse.transform_sigmas(sigmas, s, min_sigma=self.min_sigma) * self.sigma_scale

    return means, sigmas, values
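# Hedged illustration of the 'full' adaptivity mode above: coordinate channels are
# concatenated to the input, CoordConv-style. The grid construction here is an
# assumption about what self.coords contains; only the concatenation pattern matches.
import torch

b, c, h, w = 2, 3, 4, 5
x = torch.randn(b, c, h, w)
ys = torch.linspace(0, 1, h)[None, :, None].expand(1, h, w)   # row coordinate channel
xs = torch.linspace(0, 1, w)[None, None, :].expand(1, h, w)   # column coordinate channel
coords = torch.cat([ys, xs], dim=0)                           # (2, h, w)
x_aug = torch.cat([x, coords[None].expand(b, 2, h, w)], dim=1)
print(x_aug.size())                                           # torch.Size([2, 5, 4, 5])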
def hyper(self, x):

    b = x.size(0)
    size = (self.size, self.size)

    # Expand parameters along batch dimension
    means  = self.pmeans[None, :, :].expand(b, self.size, 2)
    sigmas = self.psigmas[None, :].expand(b, self.size)
    values = self.pvalues[None, :].expand(b, self.size)

    means, sigmas = sparse.transform_means(means, size), sparse.transform_sigmas(sigmas, size)

    return means, sigmas, values
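# Side note on the expand() calls above (a general PyTorch fact, not specific to
# this code): expand() returns a broadcasting view and allocates no new memory,
# so the same learned parameters are shared across the whole batch.
import torch

p = torch.randn(8, 2)
batched = p[None, :, :].expand(4, 8, 2)
print(batched.data_ptr() == p.data_ptr())   # True: same underlying storage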
def hyper(self, input, prep=None):
    """
    Evaluates hypernetwork.
    """
    b, c, h, w = input.size()

    thetas = self.preprocess(input)  # Compute the affine parameters of the attention quadrangle

    b, g, _, _ = thetas.size()  # (b, g, 2, 3)

    thetas = thetas * self.scale + self.identity[None, None, :, :]

    mat, vec = thetas[:, :, :, :2], thetas[:, :, :, 2]

    corners = self.corners[None, None, :, :].expand(b, g, 4, 2)

    # compute the corners of the attention quad
    quad = torch.bmm(corners.view(b*g, 4, 2), mat.view(b*g, 2, 2)) + vec.view(b*g, 1, 2)

    k, k, _ = self.grid.size()  # (k, k, 4)

    # Fit to the max pixel values
    quad = sparse.transform_means(quad.view(b, g*4, 2), (h, w)).view(b, g, 4, 2)

    # Interpolate between the four corners of the quad
    grid = self.grid[None, None, :, :, :, None]  # b, g, k, k, 4, 2
    quad = quad[:, :, None, None, :, :]

    res = (grid * quad).sum(dim=4)
    assert res.size() == (b, g, k, k, 2)

    means = res.view(b, g * k * k, 2)

    # Expand sigmas
    sigmas = self.sigmas[None, :, None].expand(b, g, (k*k)).contiguous().view(b, (g*k*k))
    sigmas = sparse.transform_sigmas(sigmas, (h, w))
    sigmas = sigmas * self.sigma_scale + self.min_sigma

    values = self.one[None, :].expand(b, k*k*g)

    return means, sigmas, values
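# Hedged illustration of the corner mapping above: a predicted 2x2 matrix and
# translation are applied to canonical corners via bmm. The corner layout and the
# example values below are assumptions for illustration only.
import torch

corners = torch.tensor([[[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]]])  # (1, 4, 2)
mat = torch.eye(2)[None, :, :]         # (1, 2, 2) identity as example transform
vec = torch.tensor([[[0.5, 0.5]]])     # (1, 1, 2) example translation
print(torch.bmm(corners, mat) + vec)   # each corner shifted by (0.5, 0.5)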
def hyper(self, input, prep=None):
    """
    Evaluates hypernetwork.
    """
    b, c, h, w = input.size()

    bboxes = self.preprocess(input)  # (b, g, 4)

    b, g, _ = bboxes.size()

    # ensure that the bounding box covers a reasonable area of the image at the start
    bboxes = bboxes + self.bbox_offset[None, None, :]

    # Fit to the max pixel values
    bboxes = sparse.transform_means(bboxes, (h, h, w, w))

    # vertical (height) and horizontal (width) bounds of each box
    vmin, vmax, hmin, hmax = bboxes[:, :, 0], bboxes[:, :, 1], bboxes[:, :, 2], bboxes[:, :, 3]
    vrange, hrange = vmax - vmin, hmax - hmin

    pih, _ = self.pixel_indices.size()
    pixel_indices = self.pixel_indices.view(g, pih // g, 2)
    pixel_indices = pixel_indices[None, :, :, :].expand(b, g, pih // g, 2)

    range = torch.cat([vrange[:, :, None, None], hrange[:, :, None, None]], dim=3)
    range = range.expand(b, g, pih//g, 2)

    min = torch.cat([vmin[:, :, None, None], hmin[:, :, None, None]], dim=3)
    min = min.expand(b, g, pih//g, 2)

    # map the pixel index template into each bounding box
    means = pixel_indices * range + min
    means = means.view(b, pih, 2)

    # Expand sigmas
    sigmas = self.sigmas[None, :, None].expand(b, g, pih//g).contiguous().view(b, pih)
    sigmas = sparse.transform_sigmas(sigmas, (h, w))
    sigmas = sigmas * self.sigma_scale + self.min_sigma

    values = self.one[None, :].expand(b, pih)

    return means, sigmas, values
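# Hedged illustration of the `pixel_indices * range + min` mapping above, assuming
# the index template lives in the unit square: points in [0, 1]^2 are scaled and
# shifted into a bounding box with corners (vmin, hmin) and (vmax, hmax).
import torch

unit = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])   # template points in [0, 1]^2
vmin, vmax, hmin, hmax = 2.0, 10.0, 4.0, 8.0                # example box
rng = torch.tensor([vmax - vmin, hmax - hmin])
mn = torch.tensor([vmin, hmin])
print(unit * rng + mn)   # tensor([[ 2., 4.], [ 6., 6.], [10., 8.]])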
def forward(self, x):

    b, c, h, w = x.size()

    params = self.encoder(x)

    ls = self.latent.size()
    s, e = ls[:-1], ls[-1]

    assert params.size() == (b, len(ls))

    means  = sparse.transform_means(params[:, None, None, :-1], s, method=self.method)
    sigmas = sparse.transform_sigmas(params[:, None, None, -1], s, min_sigma=self.min_sigma) * self.sigma_scale

    if self.smp:

        indices = sparse.ngenerate(means, self.gadditional, self.radditional, rng=s,
                                   relative_range=self.region, cuda=x.is_cuda)

        vs = 2**len(s) + self.radditional + self.gadditional
        assert indices.size() == (b, 1, vs, len(s)), f'{indices.size()}, {(b, 1, vs, len(s))}'

        indfl = indices.float()

        # Mask for duplicate indices
        dups = util.nduplicates(indices).to(torch.bool)

        # compute (unnormalized) densities under the given MVNs (proportions)
        props = sparse.densities(indfl, means, sigmas).clone()
        assert props.size() == (b, 1, vs, 1)  #?

        props[dups, :] = 0
        props = props / props.sum(dim=2, keepdim=True)  # normalize over all points of a given index tuple

        weights = props.sum(dim=-1)  # - sum out the MVNs

        assert indices.size() == (b, 1, vs, len(s))
        assert weights.size() == (b, 1, vs)

        indices, weights = indices.squeeze(1), weights.squeeze(1)

    else:
        vs = 1
        indices = means.floor().to(torch.long).detach().squeeze(1)

    # Select a single code from the latent space (per instance in batch).
    # When sampling, this is a weighted sum, when not sampling, just one.
    indices = indices.view(b * vs, len(s))

    # checks to prevent segfaults
    if util.contains_nan(indices):
        print(params)
        raise Exception('Indices contain NaN')

    if indices[:, 0].max() >= s[0] or indices[:, 1].max() >= s[1]:
        print(indices.max())
        print(params)
        raise Exception('Indices out of bounds')

    if len(s) == 1:
        code = self.latent[indices[:, 0], :]
    elif len(s) == 2:
        code = self.latent[indices[:, 0], indices[:, 1], :]
    elif len(s) == 3:
        code = self.latent[indices[:, 0], indices[:, 1], indices[:, 2], :]
    else:
        # - ugly hack, until I figure out how to do this for n dimensions
        raise Exception('Dimensionality above 3 not supported.')

    assert code.size() == (b * vs, e), f'{code.size()} --- {(b*vs, e)}'

    if self.smp:
        code = code.view(b, vs, e)
        code = code * weights[:, :, None]
        code = code.sum(dim=1)
    else:
        code = code.view(b, e)

    assert code.size() == (b, e)

    # Decode
    result = self.decoder(code)
    assert result.size() == (b, c, h, w), f'{result.size()} --- {(b, c, h, w)}'

    return result
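# Minimal training-step sketch for the autoencoder above. The constructor, data
# loader, optimizer choice and learning rate are assumptions, not taken from this
# code; only the forward()/reconstruction pattern is implied by the method itself.
#
#   import torch
#   import torch.nn.functional as F
#
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   for x, _ in loader:
#       rec = model(x)                 # forward() above returns the reconstruction
#       loss = F.mse_loss(rec, x)
#       opt.zero_grad()
#       loss.backward()
#       opt.step()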