Example No. 1
    def _non_shared_sample(self, l: NetworkOutput) -> NormalizedTensor:
        """ sample from model """
        logit_probs = l.pis  # NCKHW
        N, C, K, H, W = logit_probs.shape

        # sample mixture indicator from softmax
        u = torch.zeros_like(logit_probs).uniform_(1e-5, 1. - 1e-5)  # NCKHW
        sel = torch.argmax(
            logit_probs - torch.log(-torch.log(u)),  # gumbel sampling
            dim=2
        )  # argmax over K, results in NCHW, specifies for each c: which of the K mixtures to take
        assert sel.shape == (N, C, H, W), (sel.shape, (N, C, H, W))

        sel = sel.unsqueeze(2)  # NC1HW

        # Pick, per (n, c, h, w), the mean and log-scale of the selected mixture component (NCHW each).
        means = torch.gather(l.means, 2, sel).squeeze(2)
        log_scales = torch.clamp(torch.gather(l.sigmas, 2, sel).squeeze(2),
                                 min=self.min_sigma)

        # Sample from the resulting logistic, which now has essentially 1 mixture component only.
        # We use inverse transform sampling, i.e., X ~ logistic; generate u ~ Uniform; x = CDF^-1(u),
        #  where CDF^-1 for the logistic is CDF^-1(y) = \mu + \sigma * log(y / (1-y))
        u = torch.zeros_like(means).uniform_(1e-5, 1. - 1e-5)  # NCHW
        x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))  # NCHW

        if l.lambdas is not None:
            assert C == 3

            clamp = lambda x_: torch.clamp(x_, -1., 1.)

            # Be careful about coefficients! We need to use the correct selection mask, namely the one for the G and
            #  B channels, as we update the G and B means! Doing torch.gather(coeffs, 2, sel) would be completely
            #  wrong.
            coeffs = torch.tanh(l.lambdas)
            sel_g, sel_b = sel[:, 1, ...], sel[:, 2, ...]
            coeffs_g_r = torch.gather(coeffs[:, 0, ...], 1, sel_g).squeeze(1)
            coeffs_b_r = torch.gather(coeffs[:, 1, ...], 1, sel_b).squeeze(1)
            coeffs_b_g = torch.gather(coeffs[:, 2, ...], 1, sel_b).squeeze(1)

            # Note: In theory, we should go step by step over the channels and update means with previously sampled
            # xs. But because of the math above (x = means + ...), we can just update the means here and it's all good.
            x0 = clamp(x[:, 0, ...])
            x1 = clamp(x[:, 1, ...] + coeffs_g_r * x0)
            x2 = clamp(x[:, 2, ...] + coeffs_b_r * x0 + coeffs_b_g * x1)
            x = torch.stack((x0, x1, x2), dim=1)

        return NormalizedTensor(x, self.L, centered=True)
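
Both sampling tricks in this example can be checked in isolation: the Gumbel-max draw reproduces the softmax over the mixture logits \pi_k, and the inverse CDF turns uniform noise into logistic samples. Below is a minimal, self-contained sketch of the same arithmetic in plain PyTorch; all names and values are illustrative and not part of the class above.

import torch

# Gumbel-max sampling of a mixture indicator, then inverse-transform
# sampling of the selected logistic component (same math as above).
torch.manual_seed(0)
K = 3
logits = torch.tensor([0.5, 1.5, -1.0])        # unnormalized mixture logits (pi_k)
means = torch.tensor([-0.5, 0.0, 0.7])         # mu_k
log_scales = torch.tensor([-2.0, -1.5, -2.5])  # log sigma_k

n = 100_000
u = torch.empty(n, K).uniform_(1e-5, 1. - 1e-5)
sel = torch.argmax(logits - torch.log(-torch.log(u)), dim=1)  # Gumbel-max trick

# Empirical selection frequencies should match softmax(logits).
print(torch.bincount(sel, minlength=K).float() / n)
print(torch.softmax(logits, dim=0))

# Inverse CDF of the logistic: x = mu + sigma * log(u / (1 - u)).
u2 = torch.empty(n).uniform_(1e-5, 1. - 1e-5)
x = means[sel] + torch.exp(log_scales[sel]) * (torch.log(u2) - torch.log(1. - u2))
print(x.mean())  # close to the pi-weighted mean of the mu_k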
Example No. 2
    def _decode(self, pin):
        pin_bpg = self._path_for_bpg(pin)
        with self.times.run('BPG'):
            x_l: NormalizedTensor = self._decode_bpg(pin_bpg)
        with open(pin, 'rb') as fin:
            dmll = self.blueprint.losses.loss_dmol_rgb

            with self.times.prefix_scope('RGB'):
                with self.times.run('get_P'):
                    # Bits actually spent by BPG, per RGB pixel.
                    actual_bpp = os.path.getsize(pin_bpg) * 8 / (
                            np.prod(x_l.t.shape) / 3)
                    network_out: prob_clf.NetworkOutput = self.blueprint.forward_lossy(
                        x_l, torch.tensor([actual_bpp], device=pe.DEVICE))
                    # l, dec_out_prev = self.blueprint.net.get_P(
                    #         scale, bn_prev, dec_out_prev)
                # NCHW [-1, 1], residual
                res_decoded = self.decode_rgb(dmll, network_out, fin)
            assert fin.read(4) == _MAGIC_VALUE_SEP  # assert valid file
        assert res_decoded is not None  # assert decoding worked

        res_decoded_sym = NormalizedTensor(res_decoded, L=511,
                                           centered=True).to_sym()
        img = x_l.to_sym().t + res_decoded_sym.t
        return img  # 1CHW int64
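
Two numbers in this example are worth unpacking: `actual_bpp` is just the BPG file size in bits divided by the number of pixels (the division by 3 removes the channel dimension of the 1CHW tensor), and the residual is wrapped with L=511 levels because the difference of two 8-bit images lies in [-255, 255], i.e., 511 possible symbols. Here is a small sketch of that arithmetic with stand-in names and tensors; nothing in it comes from the actual decoder.

import os
import numpy as np
import torch

def bits_per_pixel(path_bpg: str, lossy_shape) -> float:
    # File size in bits divided by the pixel count (C assumed to be 3).
    return os.path.getsize(path_bpg) * 8 / (np.prod(lossy_shape) / 3)

# Final image = lossy reconstruction + decoded residual, both as int64 symbols.
x_lossy_sym = torch.randint(0, 256, (1, 3, 4, 4), dtype=torch.int64)  # stand-in for x_l.to_sym().t
res_sym = torch.randint(-255, 256, (1, 3, 4, 4), dtype=torch.int64)   # stand-in for res_decoded_sym.t
img = x_lossy_sym + res_sym                                           # 1CHW int64, like _decode's return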
Example No. 3
    def _extract_non_shared(self, x: NormalizedTensor, l: NetworkOutput):
        """
        :param x: targets, NCHW
        :param l: output of net, NKpHW, see above
        :return:
            x NC1HW,
            logit_probs NCKHW (mixture probabilities, i.e., \pi_k),
            means NCKHW,
            log_scales NCKHW (log of the scale parameters \sigma_k),
            K (number of mixtures)
        """
        x_raw = x.get()

        N, C, H, W = x_raw.shape

        logit_probs = l.pis  # NCKHW
        means = l.means  # NCKHW
        log_scales = torch.clamp(l.sigmas,
                                 min=self.min_sigma)  # NCKHW, clamped to >= self.min_sigma

        x_raw = x_raw.reshape(N, C, 1, H, W)

        if l.lambdas is not None:
            assert C == 3  # Coefficients only supported for C==3, see note where we define _NUM_PARAMS_RGB
            # NCKHW; along C these are basically coeffs_g_r, coeffs_b_r, coeffs_b_g.
            coeffs = torch.tanh(l.lambdas)
            means_r, means_g, means_b = \
                means[:, 0, ...], means[:, 1, ...], means[:, 2, ...]  # each NKHW
            coeffs_g_r, coeffs_b_r, coeffs_b_g = \
                coeffs[:, 0, ...], coeffs[:, 1, ...], coeffs[:, 2, ...]  # each NKHW
            x_reg = means if self._self_auto_reg else x_raw
            means = torch.stack(
                (means_r, means_g + coeffs_g_r * x_reg[:, 0, ...],
                 means_b + coeffs_b_r * x_reg[:, 0, ...] +
                 coeffs_b_g * x_reg[:, 1, ...]),
                dim=1)  # NCKHW again

            if self._means_oracle:
                mse_pre = F.mse_loss(means, x_raw)
                diff = (x_raw - means).detach()
                means += self._means_oracle * diff
                mse_cur = F.mse_loss(means, x_raw)
                self.summarizer.register_scalars(
                    'val',
                    {f'dmll/0/oracle_mse_impact': lambda: mse_cur - mse_pre})

            # TODO: will not work for RGB baseline
            self.summarizer.register_scalars(
                'train', {
                    f'dmll/0/coeffs_{c}':
                    lambda c=c: coeffs[:, c, ...].detach().mean()
                    for c in range(C)
                })

        x = NormalizedTensor(x_raw, x.L)
        return x, logit_probs, means, log_scales
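
The channel coupling here follows the PixelCNN++-style conditioning for RGB mixtures: the G and B means receive a linear contribution (with tanh-bounded coefficients) from the already-known R, respectively R and G, values, which are the target pixels during training or the predicted means when `self._self_auto_reg` is set. A minimal sketch of that update for a single pixel and K mixture components; values are illustrative only.

import torch

# RGB mean coupling for one pixel; the real code does this on NCKHW tensors.
K = 2
means = torch.randn(3, K)               # mu_r, mu_g, mu_b for each mixture k
coeffs = torch.tanh(torch.randn(3, K))  # coeffs_g_r, coeffs_b_r, coeffs_b_g
x = torch.tensor([0.25, -0.1, 0.6])     # target pixel (r, g, b) in [-1, 1]

mu_r = means[0]
mu_g = means[1] + coeffs[0] * x[0]                     # G mean shifted by the R target
mu_b = means[2] + coeffs[1] * x[0] + coeffs[2] * x[1]  # B mean shifted by the R and G targets
means_coupled = torch.stack((mu_r, mu_g, mu_b))        # (3, K), cf. the NCKHW stack above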
Example No. 4
    def optimize(self,
                 res: NormalizedTensor,
                 network_out: prob_clf.NetworkOutput) -> Tuple:
        if VERBOSE_TAU:
            cuda_timer.sync()
            start = time.time()

        with torch.enable_grad():
            network_out_ss = prob_clf.map_over(
                network_out, lambda f: f.detach()[..., ::self._subsampling, ::self._subsampling])

            # Subsample residual.
            res_ss = NormalizedTensor(
                res.t.detach()[..., ::self._subsampling, ::self._subsampling],
                res.L,
                res.centered)

            taus = self._get_taus(network_out_ss)

            optim_cls = {
                'SGD': torch.optim.SGD,
                'Adam': torch.optim.Adam,
                'RMSprop': torch.optim.RMSprop}[self._optim_cls]

            optim = optim_cls(taus.values(), lr=self._lr, **self._optim_params)
            tau_overhead_bytes = (sum(np.prod(tau.shape) * 4 for tau in taus.values())  # 4 bytes per float32.
                                  if not self._ignore_overhead
                                  else 0)

            loss_prev = None
            diffs = collections.deque(maxlen=5)
            initial = None

            losses = [] if self._plot_loss else None

            for i in range(self._num_iter):
                for tau in taus.values():
                    if tau.grad is not None:
                        tau.grad.detach_()
                        tau.grad.zero_()
                        tau.grad = None

                # forward pass
                network_out_ss_tau = self._get_modified_network_out(network_out_ss, taus)
                nll = self.loss_dmol_rgb.forward(res_ss, network_out_ss_tau)

                loss = nll.mean()
                if self._plot_loss:
                    losses.append(loss.item())

                if initial is None:
                    initial = loss.item()

                if loss_prev is not None:
                    diff = loss_prev - loss.item()
                    printv(f'\ritr {i}: {loss.item():.3f} '
                           f'// {diff:.3e} '
                           f'// gain: {initial - loss.item():.5f}', end='', flush=True)
                    diffs.append(abs(diff))
                    if self._early_stop and (len(diffs) >= 5 and np.mean(diffs) < 1e-4):
                        printv('\ndone after', i)
                        break
                loss_prev = loss.item()
                loss.backward()
                optim.step()
                optim.zero_grad()

            if losses:
                print('\n\n***\n', losses, '\n***\n\n')
                self._summary.losses.append(losses)

        if VERBOSE_TAU:
            cuda_timer.sync()
            # noinspection PyUnboundLocalVariable
            diff = time.time() - start
            self._summary.add_time(diff)
            printv(f'time for tau optim: {diff}')

        self._summary.add_diff(np.mean(diffs))

        # Note that this is not the real gain, since it's sub-sampled.
        # Note that this does not take overhead into account!
        final_subsampled_gain = initial - loss.item()
        if final_subsampled_gain < 0:
            printv('*** Was for nothing...')
            self._summary.num_fails += 1
            nll = self.loss_dmol_rgb.forward(res, network_out)
            return nll, None
        else:
            self._summary.add_gain(final_subsampled_gain)
            for k, tau in taus.items():
                self._summary.add_param(f'taus_{k}', tau)
            nll = self.loss_dmol_rgb.forward(
                res, self._get_modified_network_out(network_out, taus))
            # nll is without the overhead of tau
            return nll, tau_overhead_bytes
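
The loop above fits the taus on a spatially subsampled copy of the residual and the network output, so each optimizer step is cheap, and it stops early once the moving average of the last five loss improvements drops below 1e-4; only at the end is the full-resolution loss evaluated with the optimized taus. The stopping pattern itself is easy to reproduce; here is a minimal sketch with a toy quadratic loss standing in for `loss_dmol_rgb` (all names illustrative).

import collections
import numpy as np
import torch

# Toy version of the optimization loop: optimize one parameter with Adam and
# stop once the moving average of the last five loss improvements is < 1e-4.
target = torch.tensor(2.0)
tau = torch.zeros(1, requires_grad=True)
optim = torch.optim.Adam([tau], lr=1e-1)

diffs = collections.deque(maxlen=5)
loss_prev = None
for i in range(500):
    loss = (tau - target).pow(2).mean()
    if loss_prev is not None:
        diffs.append(abs(loss_prev - loss.item()))
        if len(diffs) >= 5 and np.mean(diffs) < 1e-4:
            print('done after', i)
            break
    loss_prev = loss.item()
    optim.zero_grad()
    loss.backward()
    optim.step()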