def get_prediction(self, batch: Dict[str, Tensor]):
    """Classify every source slot in ``batch`` as galaxy or star and score the result.

    Args:
        batch: Mapping that must provide the keys ``"galaxy_bools"``, ``"locs"``,
            ``"images"``, ``"background"`` and ``"n_sources"``.

    Returns:
        A dict with:
            ``"loss"``: per-slot BCE between predicted probabilities and the true
                galaxy labels, masked to "on" slots and summed.
            ``"acc"``: number of correctly classified "on" slots divided by
                ``batch["n_sources"].sum()``.
            ``"galaxy_bools"``: flattened predictions, ``probs > 0.5``,
                zeroed on "off" slots.
            ``"star_bools"``: flattened complement of ``"galaxy_bools"``,
                also zeroed on "off" slots.
            ``"galaxy_probs"``: flattened output of ``self.forward``.
    """
    # Ground-truth labels, flattened so they align element-wise with ``probs``.
    true_galaxy = batch["galaxy_bools"].reshape(-1)

    # Fold the three leading axes (n, nth, ntw) into one and prepend a
    # singleton axis; this is the locs layout handed to ``self.forward``.
    tiled_locs = rearrange(batch["locs"], "n nth ntw ms hw -> 1 (n nth ntw) ms hw")

    # Stack images and background along dim 1, tile them, then fold the
    # per-tile axes into a single leading axis to match ``tiled_locs``.
    stacked = torch.cat((batch["images"], batch["background"]), dim=1)
    ptiles = get_images_in_tiles(stacked, self.tile_slen, self.ptile_slen)
    ptiles = rearrange(ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

    probs = self.forward(ptiles, tiled_locs).reshape(-1)

    # Flattened mask of source slots that are "on" given n_sources; every
    # metric below is multiplied by it so "off" slots contribute nothing.
    # (dtype is whatever get_is_on_from_n_sources returns -- not visible here.)
    on_mask = get_is_on_from_n_sources(batch["n_sources"], self.max_sources).reshape(-1)

    # Cross entropy is computed for every slot, then restricted to "on" slots.
    bce = BCELoss(reduction="none")
    loss = (bce(probs, true_galaxy) * on_mask).sum()

    # MAP classification at a fixed 0.5 threshold; stars are the "on" non-galaxies.
    predicted_galaxy = (probs > 0.5).float() * on_mask
    predicted_star = (1 - predicted_galaxy) * on_mask

    # NOTE(review): if the batch has no sources this is 0/0 and yields NaN.
    n_correct = (predicted_galaxy.eq(true_galaxy) * on_mask).sum()
    acc = n_correct / batch["n_sources"].sum()

    return {
        "loss": loss,
        "acc": acc,
        "galaxy_bools": predicted_galaxy,
        "star_bools": predicted_star,
        "galaxy_probs": probs,
    }