def __init__(self, network_out: prob_clf.NetworkOutput,
             x_r: NormalizedTensor, x_l: NormalizedTensor):
    """Keep the network output and precompute the symbol-space residual x_r - x_l.

    :param network_out: network output, stored unchanged as ``self.network_out``.
    :param x_r: normalized tensor; its symbol form (``to_sym()``) is stored as
        ``self.x_r``.
    :param x_l: normalized tensor whose symbol form is subtracted from ``x_r``'s.
    """
    self.network_out = network_out
    sym_r = x_r.to_sym()
    self.x_r = sym_r
    # Residual computed on the raw ``.t`` payloads of the two symbol tensors.
    diff = sym_r.t - x_l.to_sym().t
    # NOTE(review): L=511 presumably covers differences of two 256-level symbol
    # images (values -255..255), hence centered=True; confirm against the
    # symbol range of x_r / x_l.
    self.res_sym = SymbolTensor(diff, L=511, centered=True)
    # Residual converted back via to_norm().
    self.res = self.res_sym.to_norm()
    # Starts unset; presumably filled lazily elsewhere in the class (confirm).
    self._mean_img = None
示例#2
0
    def forward(self, x_n: NormalizedTensor, side_information=None) -> prob_clf.NetworkOutput:
        """Run the network on ``x_n`` and return whatever ``self.tail`` produces.

        Pipeline: head -> down -> body (with residual add) -> after_skip, then an
        optional U-Net style skip convolution over ``cat(features, head_out)``,
        then an optional side-information convolution, then the tail.

        :param x_n: normalized input; ``x_n.get()`` is fed to ``self.head``.
        :param side_information: forwarded to ``self.side_information_conv`` when
            ``self.side_information_mode`` is truthy; ignored otherwise.
        :return: the result of ``self.tail``.
        """
        # Register the input (as uint8 symbols) once under the 'train' summary;
        # the lambda defers the conversion until the summarizer needs it.
        input_as_uint8 = lambda: x_n.to_sym().get().to(torch.uint8)
        self.summarizer.register_images('train', {'input': input_as_uint8}, only_once=True)

        head_out = self.head(x_n.get())
        features = self.down(head_out)
        features = self.body(features) + features
        # 'goes up again' -- presumably undoes the resolution change of self.down.
        features = self.after_skip(features)

        if self.unet_skip_conv is not None:
            # Concatenate along dim 1 with the head output before fusing.
            merged = torch.cat((features, head_out), dim=1)
            features = self.unet_skip_conv(merged)

        if self.side_information_mode:
            features = self.side_information_conv(features, side_information)

        return self.tail(features)
    def _plot_cdf(self,
                  x: NormalizedTensor,
                  logit_pis,
                  means,
                  log_scales,
                  axs,
                  x_range=None,
                  y_range=None,
                  num_per_series=2):
        """Plot per-pixel predicted PMFs for a small crop, next to the input crop.

        The mixture parameters are cropped to a spatial window of the first batch
        element and turned into CDF values via ``self._get_cdf``. For each channel
        and each pixel of the crop, the differences of consecutive CDF values
        (the per-bin probability mass) are drawn as a curve. A dashed vertical
        line in the same colour marks that pixel's ground-truth symbol.

        :param x: input image, NC1HW; only the first batch element is plotted.
        :param logit_pis: mixture logits, indexed as 5-D (N, C, K, H, W);
            K is presumably the number of mixture components (confirm).
        :param means: mixture means, same layout as ``logit_pis``.
        :param log_scales: mixture log-scales, same layout as ``logit_pis``.
        :param axs: iterable of ``(ax_img, ax_pmf)`` axis pairs, one per channel.
            ``ax_img`` shows the upscaled input crop and ``ax_pmf`` the PMFs. It is
            iterated twice, so it must not be a one-shot iterator.
        :param x_range: arguments for ``slice`` over the H axis. When falsy, it is
            derived with ``_get_series_range(H, num_per_series)``.
        :param y_range: arguments for ``slice`` over the W axis, derived likewise.
        :param num_per_series: forwarded to ``_get_series_range`` when a range has
            to be derived; must be positive in that case.
        :return: None; the plots are drawn onto ``axs``.
        """
        _, _, _, H, W = x.get().shape
        if not x_range:
            assert num_per_series > 0
            x_range = _get_series_range(H, num_per_series)
        if not y_range:
            assert num_per_series > 0
            y_range = _get_series_range(W, num_per_series)
        rows, cols = slice(*x_range), slice(*y_range)

        # (N, C, K, H, W) -> (1, C, K, h, w): first batch element, cropped window.
        def cut(t_):
            return t_.detach()[:1, :, :, rows, cols]

        logit_pis, means, log_scales = map(cut, (logit_pis, means, log_scales))

        # (C, h, w, Lp): Lp CDF values per pixel along the last axis.
        cdf = self._get_cdf(logit_pis, means,
                            log_scales).detach()[0, ...].cpu().numpy()
        _, H, W, Lp = cdf.shape

        # Ground-truth symbols of the crop, (C, h, w).
        sym = x.to_sym().get().detach()[0, :, 0, rows, cols].cpu().numpy()

        # Input crop as an (h, w, C) uint8 image, upscaled by pixel repetition
        # so that the few pixels are visible.
        blow_up = 8
        x_vis = x.get()[0, :, 0, rows, cols].permute(1, 2, 0)
        x_vis = normalize_to_0_1(x_vis).mul(255).round().to(torch.uint8)
        x_vis = (x_vis.detach().cpu().numpy()
                 .repeat(blow_up, axis=0)
                 .repeat(blow_up, axis=1))

        # NOTE(review): the x-axis is the bin index 0..Lp-2 while the marker is
        # the raw symbol value; they line up only if symbol k maps to bin k.
        # Confirm this for large L.
        bins = np.arange(Lp - 1)
        gts = set()
        for c, (ax_img, ax_pmf) in enumerate(axs):
            ax_img.imshow(x_vis[..., c], cmap='gray')
            # i/j index the crop; they must not reuse the name `x` (the input).
            for i in range(H):
                for j in range(W):
                    # Probability mass per bin = difference of consecutive CDF values.
                    pmf_ij = cdf[c, i, j, 1:] - cdf[c, i, j, :-1]
                    lines = ax_pmf.plot(bins,
                                        pmf_ij,
                                        linestyle='-',
                                        linewidth=0.5)
                    gt = sym[c, i, j]
                    gts.add(gt)
                    # Dashed marker at the true symbol, coloured like its curve.
                    ax_pmf.axvline(gt,
                                   color=lines[-1].get_color(),
                                   linestyle='--',
                                   linewidth=0.5)
            # Loop-invariant: set once per channel instead of once per pixel.
            ax_pmf.set_ylim(-0.1, 1.1)

        # An empty crop yields no symbols; min()/max() would raise on it.
        if gts:
            lo, hi = min(gts) - 5, max(gts) + 5
            for _, ax_pmf in axs:
                ax_pmf.set_xlim(lo, hi)