def forward(self, x, heatmap=None):
    """Concatenate coordinate channels (and optional boundary channels) onto ``x``.

    x: (batch, c, x_dim, y_dim)

    Args:
        x: Input feature map. The coordinate channels are concatenated onto it
            along dim 1, so its other dimensions must match those of the
            coordinate tensors.
        heatmap: Optional tensor whose last channel (``heatmap[:, -1:]``) is
            used as a boundary map. It is only consulted when
            ``self.with_boundary`` is true.

    Returns:
        ``torch.cat([x, coords], dim=1)``. Here ``coords`` is ``self.coords``
        repeated ``x.size(0)`` times along dim 0. When the boundary branch
        runs, it is followed by ``self.x_coords`` and ``self.y_coords`` masked
        to the positions where the boundary map, clamped to [0, 1], exceeds
        0.05. Those channels are zero everywhere else.
    """
    device = x.device
    # Align the precomputed coordinate tensors with the input's device. `.to`
    # is a no-op when they already live there; otherwise it prevents a
    # device-mismatch error in the final torch.cat.
    coords = self.coords.to(device).repeat(x.size(0), 1, 1, 1)

    if self.with_boundary and heatmap is not None:
        boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0).to(device)
        # Threshold once and reuse the mask for both coordinate axes.
        mask = boundary_channel > 0.05
        x_coords = self.x_coords.to(device)
        y_coords = self.y_coords.to(device)
        zero_tensor = torch.zeros_like(x_coords)
        xx_boundary_channel = torch.where(mask, x_coords, zero_tensor)
        yy_boundary_channel = torch.where(mask, y_coords, zero_tensor)
        coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1)

    return torch.cat([x, coords], dim=1)
def truncate(x, thres=0.1):
    """Zero out heatmap entries that are smaller than ``thres``.

    Args:
        x: Heatmap tensor.
        thres: Cut-off value. Entries strictly below it are replaced with 0;
            all other entries are passed through unchanged.

    Returns:
        A tensor like ``x`` with the small entries set to zero.
    """
    too_small = x < thres
    zeros = porch.zeros_like(x)
    return porch.where(too_small, zeros, x)