# Example 1
    def elbo(self, x: torch.Tensor, p_dists: List, q_dists: List, lv_z: List,
             lv_g: List, lv_bg: List, pa_recon: List) -> Tuple:
        """Compute the terms of the evidence lower bound (ELBO).

        Args:
            x: Input batch; used for the batch size and as the target of the
                Gaussian reconstruction likelihood.
            p_dists: Prior distributions, unpacked as (global, presence
                probabilities, where, depth, what, background).
            q_dists: Posterior distributions, unpacked in the same order.
            lv_z: Per-cell latent samples
                ``(z_pres, _, z_depth, z_what, z_where_origin)``.
            lv_g: Global latent samples; only ``lv_g[0]`` is read, and only
                in the NLL phase.
            lv_bg: Background latent samples; only ``lv_bg[0]`` is read, and
                only in the NLL phase.
            pa_recon: Reconstruction tuple ``(y, y_nobg, alpha_map, bg)``;
                only ``y`` is used here (as the likelihood mean).

        Returns:
            Tuple of three items:
                - per-sample reconstruction log-likelihood, shape ``(bs,)``;
                - list of KL terms: four auxiliary-prior KLs (pres, where,
                  depth, what), auxiliary background KL, four learned-prior
                  KLs, the global KL (which keeps a per-step dimension), and
                  the background KL;
                - list of per-sample log importance weights, empty unless
                  ``self.args.log.phase_nll`` is set.
        """

        bs = x.size(0)

        # Unpack priors and posteriors; both tuples follow the same order.
        p_global_all, p_pres_given_g_probs_reshaped, \
        p_where_given_g, p_depth_given_g, p_what_given_g, p_bg = p_dists

        q_global_all, q_pres_given_x_and_g_probs_reshaped, \
        q_where_given_x_and_g, q_depth_given_x_and_g, q_what_given_x_and_g, q_bg = q_dists

        y, y_nobg, alpha_map, bg = pa_recon

        if self.args.log.phase_nll:
            # NLL phase: flatten all per-cell latent samples so their
            # log-probs can be evaluated under the (already reshaped)
            # prior/posterior distributions further below.
            # (bs, dim, num_cell, num_cell)
            z_pres, _, z_depth, z_what, z_where_origin = lv_z
            # (bs * num_cell * num_cell, dim)
            z_pres_reshape = z_pres.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       self.args.z.z_pres_dim)
            z_depth_reshape = z_depth.permute(0, 2, 3, 1).reshape(
                -1, self.args.z.z_depth_dim)
            z_what_reshape = z_what.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       self.args.z.z_what_dim)
            z_where_origin_reshape = z_where_origin.permute(
                0, 2, 3, 1).reshape(-1, self.args.z.z_where_dim)
            # (bs, dim, 1, 1)
            z_bg = lv_bg[0]
            # (bs, step, dim, 1, 1)
            z_g = lv_g[0]
        else:
            # Training phase: only the presence sample is needed (it masks
            # the auxiliary KL terms below).
            z_pres, _, _, _, z_where_origin = lv_z

            z_pres_reshape = z_pres.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       self.args.z.z_pres_dim)

        # Linearly anneal the auxiliary presence-prior probability over
        # training steps (schedule disabled when the end step is 0).
        if self.args.train.p_pres_anneal_end_step != 0:
            self.aux_p_pres_probs = linear_schedule_tensor(
                self.args.train.global_step,
                self.args.train.p_pres_anneal_start_step,
                self.args.train.p_pres_anneal_end_step,
                self.args.train.p_pres_anneal_start_value,
                self.args.train.p_pres_anneal_end_value,
                self.aux_p_pres_probs.device)

        # Likewise anneal the scale component of the auxiliary `where`
        # prior mean; note the in-place write into row 0 of
        # self.aux_p_where_mean (presumably the scale row — TODO confirm
        # against where aux_p_where_mean is built).
        if self.args.train.aux_p_scale_anneal_end_step != 0:
            aux_p_scale_mean = linear_schedule_tensor(
                self.args.train.global_step,
                self.args.train.aux_p_scale_anneal_start_step,
                self.args.train.aux_p_scale_anneal_end_step,
                self.args.train.aux_p_scale_anneal_start_value,
                self.args.train.aux_p_scale_anneal_end_value,
                self.aux_p_where_mean.device)
            self.aux_p_where_mean[:, 0] = aux_p_scale_mean

        # Broadcast the auxiliary presence probability to one value per cell
        # in the batch: (bs * num_cell**2, 1). The double [None] indexing
        # implies aux_p_pres_probs is a 0-dim tensor.
        auxiliary_prior_z_pres_probs = self.aux_p_pres_probs[None][
            None, :].expand(bs * self.args.arch.num_cell**2, -1)

        # KL of each posterior against the fixed auxiliary priors. The
        # where/depth/what terms are gated by the sampled presence mask,
        # clamped so the multiplier is never exactly zero.
        aux_kl_pres = kl_divergence_bern_bern(
            q_pres_given_x_and_g_probs_reshaped, auxiliary_prior_z_pres_probs)
        aux_kl_where = kl_divergence(
            q_where_given_x_and_g,
            self.aux_p_where) * z_pres_reshape.clamp(min=1e-5)
        aux_kl_depth = kl_divergence(
            q_depth_given_x_and_g,
            self.aux_p_depth) * z_pres_reshape.clamp(min=1e-5)
        aux_kl_what = kl_divergence(
            q_what_given_x_and_g,
            self.aux_p_what) * z_pres_reshape.clamp(min=1e-5)

        # KL of each posterior against the learned, g-conditioned priors
        # (no presence gating here).
        kl_pres = kl_divergence_bern_bern(q_pres_given_x_and_g_probs_reshaped,
                                          p_pres_given_g_probs_reshaped)

        kl_where = kl_divergence(q_where_given_x_and_g, p_where_given_g)
        kl_depth = kl_divergence(q_depth_given_x_and_g, p_depth_given_g)
        kl_what = kl_divergence(q_what_given_x_and_g, p_what_given_g)

        kl_global_all = kl_divergence(q_global_all, p_global_all)

        # Background KLs only when the background module is enabled;
        # otherwise zero placeholders with matching device/dtype.
        if self.args.arch.phase_background:
            kl_bg = kl_divergence(q_bg, p_bg)
            aux_kl_bg = kl_divergence(q_bg, self.aux_p_bg)
        else:
            kl_bg = self.background.new_zeros(bs, 1)
            aux_kl_bg = self.background.new_zeros(bs, 1)

        # Per-pixel Gaussian reconstruction log-likelihood with fixed sigma,
        # mean given by the full reconstruction y.
        log_like = Normal(y, self.args.const.likelihood_sigma).log_prob(x)

        log_imp_list = []
        if self.args.log.phase_nll:
            # Per-latent log importance weights log p(z) - log q(z|x), used
            # for importance-weighted NLL estimation. The Bernoulli presence
            # log-prob is written out explicitly, with eps added inside the
            # logs for numerical stability.
            log_pres_prior = z_pres_reshape * torch.log(p_pres_given_g_probs_reshaped + self.args.const.eps) + \
                             (1 - z_pres_reshape) * torch.log(1 - p_pres_given_g_probs_reshaped + self.args.const.eps)
            log_pres_pos = z_pres_reshape * torch.log(q_pres_given_x_and_g_probs_reshaped + self.args.const.eps) + \
                           (1 - z_pres_reshape) * torch.log(
                1 - q_pres_given_x_and_g_probs_reshaped + self.args.const.eps)

            log_imp_pres = log_pres_prior - log_pres_pos

            log_imp_depth = p_depth_given_g.log_prob(z_depth_reshape) - \
                            q_depth_given_x_and_g.log_prob(z_depth_reshape)

            log_imp_what = p_what_given_g.log_prob(z_what_reshape) - \
                           q_what_given_x_and_g.log_prob(z_what_reshape)

            log_imp_where = p_where_given_g.log_prob(z_where_origin_reshape) - \
                            q_where_given_x_and_g.log_prob(z_where_origin_reshape)

            if self.args.arch.phase_background:
                log_imp_bg = p_bg.log_prob(z_bg) - q_bg.log_prob(z_bg)
            else:
                log_imp_bg = x.new_zeros(bs, 1)

            log_imp_g = p_global_all.log_prob(z_g) - q_global_all.log_prob(z_g)

            # Reduce every importance term to one scalar per sample: view
            # back to (bs, num_cell, num_cell, -1), then sum everything but
            # the batch dimension.
            log_imp_list = [
                log_imp_pres.view(bs, self.args.arch.num_cell,
                                  self.args.arch.num_cell,
                                  -1).flatten(start_dim=1).sum(1),
                log_imp_depth.view(bs, self.args.arch.num_cell,
                                   self.args.arch.num_cell,
                                   -1).flatten(start_dim=1).sum(1),
                log_imp_what.view(bs, self.args.arch.num_cell,
                                  self.args.arch.num_cell,
                                  -1).flatten(start_dim=1).sum(1),
                log_imp_where.view(bs, self.args.arch.num_cell,
                                   self.args.arch.num_cell,
                                   -1).flatten(start_dim=1).sum(1),
                log_imp_bg.flatten(start_dim=1).sum(1),
                log_imp_g.flatten(start_dim=1).sum(1),
            ]

        # Likelihood is summed over all non-batch dims; each per-cell KL is
        # reshaped to (bs, num_cell, num_cell, -1) and summed to (bs,).
        # kl_global_all flattens from dim 2, so it keeps its per-step
        # dimension: (bs, step).
        return log_like.flatten(start_dim=1).sum(1), \
               [
                   aux_kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_bg.flatten(start_dim=1).sum(-1),
                   kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_global_all.flatten(start_dim=2).sum(-1),
                   kl_bg.flatten(start_dim=1).sum(-1)
               ], log_imp_list