def __init__(self, network_out: prob_clf.NetworkOutput, x_r: NormalizedTensor, x_l: NormalizedTensor):
    """Bundle a network output with the symbol-space residual ``x_r - x_l``.

    :param network_out: output of the probability-classifier network, stored as-is.
    :param x_r: normalized tensor; stored on ``self.x_r`` in its symbol form.
    :param x_l: normalized tensor; its symbol form is subtracted from ``x_r``'s.

    Sets ``self.res_sym`` (residual as a ``SymbolTensor``), ``self.res`` (its
    normalized view) and resets the lazily-filled ``self._mean_img`` to None.
    """
    self.network_out = network_out
    right_sym = x_r.to_sym()
    self.x_r = right_sym
    # Raw residual of the two symbol tensors' underlying data.
    diff = right_sym.t - x_l.to_sym().t
    # NOTE(review): L=511 / centered=True presumably cover a difference of two
    # 256-level symbol images (values in [-255, 255]) — confirm against SymbolTensor.
    residual_sym = SymbolTensor(diff, L=511, centered=True)
    self.res_sym = residual_sym
    self.res = residual_sym.to_norm()
    self._mean_img = None
def forward(self, x_n: NormalizedTensor, side_information=None) -> prob_clf.NetworkOutput:
    """Run the encoder/decoder stack on ``x_n`` and return the tail's output.

    :param x_n: normalized input. Its symbol view is registered once (lazily, via a
        lambda) with the summarizer as the 'input' image of the 'train' summaries.
    :param side_information: passed to ``self.side_information_conv`` when
        ``self.side_information_mode`` is truthy; unused otherwise.
    :return: whatever ``self.tail`` produces for the final features.
    """
    self.summarizer.register_images(
        'train',
        {'input': lambda: x_n.to_sym().get().to(torch.uint8)},
        only_once=True,
    )
    # Keep the head activations around for the optional U-Net style skip below.
    head_out = self.head(x_n.get())
    downsampled = self.down(head_out)
    # Residual connection around the body, then after_skip (the upsampling path).
    features = self.after_skip(self.body(downsampled) + downsampled)
    if self.unet_skip_conv is not None:
        features = self.unet_skip_conv(torch.cat((features, head_out), dim=1))
    if self.side_information_mode:
        features = self.side_information_conv(features, side_information)
    return self.tail(features)
def _plot_cdf(self, x: NormalizedTensor, logit_pis, means, log_scales, axs,
              x_range=None, y_range=None, num_per_series=2):
    """
    Plot per-symbol probability masses for a small spatial crop, one axis pair per channel.

    The CDF is computed by ``self._get_cdf`` from the (cropped) distribution parameters.
    For each channel ``c`` the pair ``(ax_a, ax)`` taken from ``axs`` receives:
    ``ax_a`` -- the upscaled crop of the input image (grayscale), and
    ``ax`` -- one curve per crop pixel, the first difference of the CDF along the symbol
    axis, plus a dashed vertical line at that pixel's ground-truth symbol, drawn in
    the same color as the curve.

    :param x: NC1HW input; only the first batch element is plotted.
    :param logit_pis: mixture logits, NCKHW (per the original shape comment — confirm).
    :param means: mixture means, same layout as ``logit_pis``.
    :param log_scales: mixture log-scales, same layout as ``logit_pis``.
    :param axs: iterable of ``(ax_a, ax)`` pairs, one per channel, to plot into.
    :param x_range: tuple of ``slice`` arguments selecting rows (H). When falsy it is
        computed with ``_get_series_range(H, num_per_series)``.
    :param y_range: tuple of ``slice`` arguments selecting columns (W). When falsy it is
        computed with ``_get_series_range(W, num_per_series)``.
    :param num_per_series: forwarded to ``_get_series_range``; must be > 0 when a
        range has to be computed.
    :return: None; draws into ``axs``.
    """
    _, _, _, H, W = x.get().shape
    if not x_range:
        assert num_per_series > 0
        x_range = _get_series_range(H, num_per_series)
    if not y_range:
        assert num_per_series > 0
        y_range = _get_series_range(W, num_per_series)
    row_slice, col_slice = slice(*x_range), slice(*y_range)

    def _crop(t_):
        # First batch element + selected spatial window: NCKHW -> 1CKhw.
        return t_.detach()[:1, :, :, row_slice, col_slice]

    logit_pis, means, log_scales = map(_crop, (logit_pis, means, log_scales))
    # Drop the (now size-1) batch dim -> C, h, w, Lp.
    cdf = self._get_cdf(logit_pis, means, log_scales).detach()[0, ...].cpu().numpy()
    _, H, W, Lp = cdf.shape  # H, W now refer to the crop
    # Ground-truth symbols of the crop, CHW.
    sym = x.to_sym().get().detach()[0, :, 0, row_slice, col_slice].cpu().numpy()

    blow_up = 8  # upscale factor so the tiny crop is visible with imshow
    x_vis = x.get()[0, :, 0, row_slice, col_slice].permute(1, 2, 0)
    x_vis = normalize_to_0_1(x_vis).mul(255).round().to(torch.uint8)
    x_vis = x_vis.detach().cpu().numpy().repeat(blow_up, axis=0).repeat(blow_up, axis=1)

    # NOTE(review): masses are plotted at indices 0..Lp-2 while ground truths are drawn
    # at their raw symbol value, i.e. index i is assumed to be symbol i. An earlier
    # draft considered shifting by 256 when self.L > 256 — confirm the alignment.
    symbol_axis = np.arange(Lp - 1)
    gts = set()
    for c, (ax_a, ax) in enumerate(axs):
        ax_a.imshow(x_vis[..., c], cmap='gray')
        for row in range(H):
            for col in range(W):
                # First difference of the CDF along the symbol axis.
                pmf = cdf[c, row, col, 1:] - cdf[c, row, col, :-1]
                lines = ax.plot(symbol_axis, pmf, linestyle='-', linewidth=0.5)
                gt = sym[c, row, col]
                gts.add(gt)
                # Mark the true symbol in the same color as its curve.
                ax.axvline(gt, color=lines[-1].get_color(), linestyle='--', linewidth=0.5)
        ax.set_ylim(-0.1, 1.1)
    if not gts:
        # Empty crop: nothing was plotted, and min()/max() would raise on an empty set.
        return
    for _, ax in axs:
        ax.set_xlim(min(gts) - 5, max(gts) + 5)