Example #1
    def _4_likelihood(self, X, nu, phi_var, phi):
        """
        @param X: (N, D)
        @param nu: (N, K)
        @param phi_var: (K, D)
        @param phi: (K, D)
        @return: ()

        Computes the likelihood term: E_{q(Z), q(A)}[log p(X_n | Z_n, A, sigma_n^2 I)]
        Same as in the finite approach.
        """
        N, _ = X.shape
        K, D = self.K, self.D  # for notational simplicity
        constant = -0.5 * D * (self.sigma_n.log() + LOG_2PI)

        # we use the Concrete / Gumbel-softmax approximation
        Z = RelaxedBernoulli(temperature=self.T, probs=nu).rsample()

        # these terms are essentially the same, nu gets replaced by Z
        first_term = X.pow(2).sum()
        second_term = (-2 * (Z.view(N, K, 1) * phi.view(1, K, D)) *
                       X.view(N, 1, D)).sum()

        # this is Z^TE[A^TA]Z
        third_term = torch.diag(
            Z @ (phi @ phi.transpose(0, 1) + (phi_var.sum(1) * torch.eye(K)))
            @ Z.transpose(0, 1)).sum()

        nonconstant = (-0.5/(self.sigma_n**2)) * \
            (first_term + second_term + third_term)

        return constant + nonconstant
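
Note: `RelaxedBernoulli` makes the binary mask differentiable, so gradients reach `nu` through `rsample()`. A minimal standalone sketch (hypothetical shapes, not part of the method above):

import torch
from torch.distributions import RelaxedBernoulli

nu = torch.full((4, 3), 0.5, requires_grad=True)
Z = RelaxedBernoulli(temperature=torch.tensor(0.5), probs=nu).rsample()
Z.sum().backward()
print(nu.grad.shape)  # torch.Size([4, 3]): the pathwise gradient reaches the probs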
Example #2
    def _sample_bipartite(self, u_c: Tensor, u_t: Tensor) -> Tensor:
        """Samples bipartite: p(A|U_R, U_M).

        Args:
            u_c (torch.Tensor): u input for context, size `(b, n, u_dim)`.
            u_t (torch.Tensor): u input for target, size `(b, m, u_dim)`.

        Returns:
            bipartite (torch.Tensor): Bipartite graph, size `(b, m, n)`.
        """

        # Indices for pairs (u_t_i, u_c_j)
        b, n, _ = u_c.size()
        m = u_t.size(1)
        indices = torch.tensor(list(product(range(m), range(n)))).t()

        # Latent pairs (b, num_pairs, u_dim)
        pair_0 = u_t[:, indices[0]]
        pair_1 = u_c[:, indices[1]]

        # Compute logits for each pair
        logp = -0.5 * ((pair_0 - pair_1)**2).sum(dim=-1) / self.scale.exp()
        logits = logitexp(logp)

        # Sample graph from bernoulli dist (b, num_pairs)
        dist = RelaxedBernoulli(logits=logits, temperature=self.temperature)
        p_edges = dist.rsample()

        # Embed values
        bipartite = u_c.new_zeros((b, m, n))
        bipartite[:, indices[0], indices[1]] = p_edges

        return bipartite
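
The `indices`/advanced-indexing pattern above scatters the flat `(b, m*n)` edge samples into a dense `(b, m, n)` matrix. A small self-contained check (hypothetical sizes):

import torch
from itertools import product

b, m, n = 2, 3, 4
indices = torch.tensor(list(product(range(m), range(n)))).t()  # (2, m*n)
p_edges = torch.rand(b, m * n)
bipartite = torch.zeros(b, m, n)
bipartite[:, indices[0], indices[1]] = p_edges
assert torch.equal(bipartite.reshape(b, -1), p_edges)  # pairs are in row-major order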
Example #3
def sample():
    probabilities_dist = RelaxedBernoulli(temperature, probabilities)
    sample_probabilities = probabilities_dist.rsample()
    sample_probabilities = sample_probabilities.clamp(0.0, 1.0)
    sample_probabilities_index = sample_probabilities >= 0.5
    sample_probabilities = (sample_probabilities_index.float()
                            - sample_probabilities.detach()
                            + sample_probabilities)  # straight-through: hard forward, soft gradient
    return sample_probabilities, sample_probabilities_index
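
The `hard - soft.detach() + soft` pattern above is the straight-through estimator: the forward value is the hard 0/1 mask, while the backward pass differentiates through the relaxed sample. A minimal sketch (hypothetical values):

import torch
from torch.distributions import RelaxedBernoulli

probs = torch.tensor([0.3, 0.7], requires_grad=True)
soft = RelaxedBernoulli(torch.tensor(1.0), probs=probs).rsample()
hard = (soft >= 0.5).float()
st = hard - soft.detach() + soft  # forward: hard values; backward: gradient of soft
st.sum().backward()
print(st, probs.grad)  # st is binary, yet probs still receives a gradient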
Example #4
    def forward(self, inputs, rnn_hxs, masks):
        x = inputs
        m_soft = RelaxedBernoulli(1.0, logits=self.input_attention).sample()
        m_hard = 0.5 * (torch.sign(m_soft - 0.5) + 1)
        mask = m_hard - m_soft.detach() + m_soft
        x = mask * x

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        hidden_critic = self.critic(x)
        hidden_actor = self.actor(x)

        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
Example #5
    def __init__(self):
        super().__init__()

        self.encoder = ModuleCompose(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            F.relu,
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
        )

        self.decoder = ModuleCompose(
            ConvPixelShuffle(64, 32, upscale_factor=2),
            F.relu,
            ConvPixelShuffle(32, 1, upscale_factor=2),
            lambda x: x[:, 0],
        )

        # Alternatives:
        # - RelaxedBernoulli - maybe doesn't work?
        # - RelaxedOneHotCategorical
        # - RelaxedOneHotCategorical * Codebook

        self.image = ModuleCompose(
            self.encoder,
            lambda logits: RelaxedBernoulli(
                temperature=0.5,
                logits=logits,
            ).rsample(),
            self.decoder,
        )
Example #6
    def get_mask(self, batch_size=None) -> torch.Tensor:
        size = (batch_size, 1, 1)
        if self.training:
            return RelaxedBernoulli(self.temperature,
                                    self.probability).rsample(size)
        else:
            return Bernoulli(self.probability).sample(size)
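
The train/eval split above is the usual pattern: a reparameterized relaxed sample keeps the mask differentiable during training, while evaluation draws an exact 0/1 Bernoulli sample. A standalone sketch with hypothetical parameters:

import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

temperature, probability = torch.tensor(0.5), torch.tensor(0.8)
size = (16, 1, 1)
soft_mask = RelaxedBernoulli(temperature, probability).rsample(size)  # values in (0, 1), differentiable
hard_mask = Bernoulli(probability).sample(size)                       # values in {0, 1}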
Example #7
def gumbel_softmax_bit_vector_sample(logits: torch.Tensor,
                                     temperature: float = 1.0,
                                     straight_through: bool = False):
    """Samples from a Gumbel-Sotmax/Concrete of independent Bernoulli distributions.
    More details in:
    - Gumbel-Softmax: https://arxiv.org/abs/1611.01144
    - Concrete distribution: https://arxiv.org/abs/1611.00712

    Arguments:
        logits {torch.Tensor} -- tensor of logits, the output of an inference network.
            Size: [batch_size, n_bits]

    Keyword Arguments:
        temperature {float} -- temperature of the softmax relaxation. The lower the
            temperature (--> 0), the closer the samples are to discrete 0/1 values.
            (default: {1.0})
        straight_through {bool} -- Whether to use the straight-through estimator.
            (default: {False})

    Returns:
        torch.Tensor -- the relaxed sample.
            Size: [batch_size, n_bits]
    """

    sample = RelaxedBernoulli(logits=logits, temperature=temperature).rsample()

    if straight_through:
        hard_sample = (logits > 0).to(torch.float)
        sample = sample + (hard_sample - sample).detach()

    return sample
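
As the docstring notes, lower temperatures push the relaxed samples toward the corners of [0, 1]. A quick standalone illustration (hypothetical logits):

import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(1000, 8)
for t in (5.0, 1.0, 0.1):
    s = RelaxedBernoulli(logits=logits, temperature=torch.tensor(t)).rsample()
    print(t, ((s < 0.05) | (s > 0.95)).float().mean().item())  # fraction near 0/1 grows as t -> 0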
Example #8
    def forward(self, inputs, rnn_hxs, masks):
        x = inputs
        probs = F.softmax(self.input_attention, dim=0)
        probs = probs / torch.max(probs)
        m_soft = RelaxedBernoulli(1.0, probs=probs).sample()
        attn_log_probs = RelaxedBernoulli(1.0, probs=probs).log_prob(m_soft)
        mask = 0.5 * (torch.sign(m_soft - 0.5) + 1)
        x = mask * x

        if self.is_recurrent:
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        hidden_critic = self.critic(x)
        hidden_actor = self.actor(x)

        return self.critic_linear(
            hidden_critic), hidden_actor, rnn_hxs, attn_log_probs
Example #9
    def _sample_dag(self, u_c: Tensor) -> Tensor:
        """Samples DAG from context data: p(G|U_R).

        Args:
            u_c (torch.Tensor): u input for context, size `(b, n, u_dim)`.

        Returns:
            graph (torch.Tensor): Sampled DAG, size `(b, n, n)`.
        """

        # Data size
        b, n, _ = u_c.size()

        # Ordering by log CDF
        log_cdf = (0.5 * (u_c / 2**0.5).erf() + 0.5).log().sum(dim=-1)
        u_c_sorted, sort_idx = log_cdf.sort()

        # Indices of upper triangular adjacency matrix for DAG
        indices = torch.triu_indices(n, n, offset=1)

        # Latent pairs (b, num_pairs)
        pair_0 = u_c_sorted[:, indices[0]]
        pair_1 = u_c_sorted[:, indices[1]]

        # Compute logits for each pair
        logp = -0.5 * (pair_0 - pair_1)**2 / self.scale.exp()
        logits = logitexp(logp)

        # Sample graph from bernoulli dist (b, num_pairs)
        dist = RelaxedBernoulli(logits=logits, temperature=self.temperature)
        sorted_graph = dist.rsample()

        # Embed upper triangular values into the adjacency matrix
        graph = u_c.new_zeros((b, n, n))
        graph[:, indices[0], indices[1]] = sorted_graph

        # Unsort index of DAG to data order
        col_idx = torch.argsort(sort_idx)
        col_idx = col_idx.unsqueeze(1).repeat(1, n, 1)

        # Swap to unsort: 1. columns, 2. indices as columns
        graph = torch.gather(graph, -1, col_idx)
        graph = torch.gather(graph.permute(0, 2, 1), -1, col_idx)
        graph = graph.permute(0, 2, 1)

        return graph
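
The two `torch.gather` calls at the end apply the inverse permutation first to columns and then (after a transpose) to rows. A small self-contained check of that unsorting trick (hypothetical sizes):

import torch

b, n = 2, 5
graph = torch.rand(b, n, n)
sort_idx = torch.stack([torch.randperm(n) for _ in range(b)])
inv = torch.argsort(sort_idx)
col_idx = inv.unsqueeze(1).repeat(1, n, 1)

out = torch.gather(graph, -1, col_idx)                 # unsort columns
out = torch.gather(out.permute(0, 2, 1), -1, col_idx)  # unsort rows
out = out.permute(0, 2, 1)

for i in range(b):  # reference: direct double indexing
    assert torch.equal(out[i], graph[i][inv[i]][:, inv[i]])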
Example #10
    def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:

        p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
        p_depth_std, p_what_mean, p_what_std = ss

        if phase_use_mode:
            z_pres = (p_pres_logits > 0).float()
        else:
            z_pres = RelaxedBernoulli(logits=p_pres_logits, temperature=self.args.train.tau_pres).rsample()

        # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
        if phase_use_mode:
            z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
        else:
            z_where_scale, z_where_shift = \
                Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

        # z_where_origin: (bs, dim, num_cell, num_cell)
        z_where_origin = \
            torch.cat([z_where_scale.detach(), z_where_shift.detach()], dim=1)

        z_where_shift = \
            (2. / self.args.arch.num_cell) * \
            (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

        scale, ratio = z_where_scale.chunk(2, 1)
        scale = scale.sigmoid()
        ratio = torch.exp(ratio)
        ratio_sqrt = ratio.sqrt()
        z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)
        # z_where: (bs, dim, num_cell, num_cell)
        z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

        if phase_use_mode:
            z_depth = p_depth_mean
            z_what = p_what_mean
        else:
            z_depth = Normal(p_depth_mean, p_depth_std).rsample()
            z_what = Normal(p_what_mean, p_what_std).rsample()

        z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim). \
            view(-1, self.args.z.z_what_dim, 1, 1)

        if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
            o = self.z_what_decoder_net(z_what_reshape)
            o = o.sigmoid()
            a = o.new_ones(o.size())
        elif self.args.arch.phase_overlap:
            o, a = self.z_what_decoder_net(z_what_reshape).split([self.args.data.inp_channel, 1], dim=1)
            o, a = o.sigmoid(), a.sigmoid()
        else:
            raise NotImplementedError

        lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
        pa = [o, a]

        return pa, lv
Example #11
def gumbel_softmax_bit_vector_sample(logits: torch.Tensor,
                                     temperature: float = 1.0,
                                     straight_through: bool = False):

    sample = RelaxedBernoulli(logits=logits, temperature=temperature).rsample()

    if straight_through:
        hard_sample = (logits > 0).to(torch.float)
        sample = sample + (hard_sample - sample).detach()

    return sample
Example #12
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    if isinstance(distr, (Categorical, OneHotCategorical)):
        if straight_through:
            gumbel_distr = RelaxedOneHotCategoricalStraightThrough(
                temperature, probs=distr.probs)
        else:
            gumbel_distr = RelaxedOneHotCategorical(temperature,
                                                    probs=distr.probs)
    elif isinstance(distr, Bernoulli):
        if straight_through:
            gumbel_distr = RelaxedBernoulliStraightThrough(temperature,
                                                           probs=distr.probs)
        else:
            gumbel_distr = RelaxedBernoulli(temperature, probs=distr.probs)
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")
    return gumbel_distr.rsample((n, ))
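
A hedged usage sketch for the dispatcher above, using only the plain (non-straight-through) path and torch's own distributions; the `*StraightThrough` classes come from the snippet's own imports:

import torch
from torch.distributions import Bernoulli

distr = Bernoulli(probs=torch.tensor([0.2, 0.9]))
samples = rsample_gumbel_softmax(distr, n=10, temperature=torch.tensor(0.5))
print(samples.shape)  # torch.Size([10, 2]): n relaxed samples per event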
Example #13
    def forward(self,
                data_batch: dict,
                temperature=1.0,
                depth_scale=10.0,
                fast=False,
                **kwargs):
        pd_dict = dict()
        image = data_batch['image']
        b, c0, h0, w0 = image.size()
        A = self.num_anchors_per_cell

        # ---------------------------------------------------------------------------- #
        # CNN encodes feature maps
        # ---------------------------------------------------------------------------- #
        fg_feature = self.fg_encoder(image)
        fg_feature = self.rpn(fg_feature)
        _, c1, h1, w1 = fg_feature.size()

        # ---------------------------------------------------------------------------- #
        # Relaxed Bernoulli z_pres
        # ---------------------------------------------------------------------------- #
        latent_pres = self.latent_pres(fg_feature)  # (b, A, h1, w1)
        z_pres_p = torch.sigmoid(latent_pres)
        z_pres_p = z_pres_p.reshape(b, -1)  # (b, A * h1 * w1)
        # Clamp to avoid exploding gradients at probabilities of exactly 0 or 1
        z_pres_p = z_pres_p.clamp(min=self._eps, max=1.0 - self._eps)
        z_pres_post = RelaxedBernoulli(z_pres_p.new_tensor(temperature),
                                       probs=z_pres_p)
        if self.training:
            z_pres = z_pres_post.rsample()
        else:
            z_pres = z_pres_p

        pd_dict['z_pres'] = z_pres  # (b, A * h1 * w1)
        pd_dict['z_pres_p'] = z_pres_p
        pd_dict['z_pres_post'] = z_pres_post

        # ---------------------------------------------------------------------------- #
        # Gaussian z_depth
        # ---------------------------------------------------------------------------- #
        latent_depth = self.latent_depth(fg_feature)  # (b, A * 2, h1, w1)
        z_depth_loc = latent_depth.narrow(1, 0, A)
        z_depth_scale = F.softplus(latent_depth.narrow(1, A, A))
        z_depth_post = Normal(z_depth_loc.reshape(b, -1),
                              z_depth_scale.reshape(b, -1))
        if self.training:
            z_depth = z_depth_post.rsample()  # (b, A * h1 * w1)
        else:
            z_depth = z_depth_loc

        pd_dict['z_depth_post'] = z_depth_post  # (b, A * h1 * w1)
        pd_dict['z_depth'] = z_depth  # (b, A * h1 * w1)

        # ---------------------------------------------------------------------------- #
        # Gaussian z_where
        # (offset_x, offset_y, scale_x, scale_y)
        # ---------------------------------------------------------------------------- #
        latent_where = self.latent_where(fg_feature)  # (b, A * 8, h1, w1)
        latent_where = latent_where.reshape(b, A, 8, h1,
                                            w1)  # (b, A, 8, h1, w1)
        latent_where = latent_where.permute(
            0, 1, 3, 4, 2).contiguous()  # (b, A, h1, w1, 8)
        latent_where = latent_where.reshape(b, A * h1 * w1, 8)
        z_where_loc = latent_where.narrow(-1, 0, 4)
        z_where_scale = F.softplus(latent_where.narrow(-1, 4, 4))
        z_where_post = Normal(z_where_loc, z_where_scale)
        if self.training:
            z_where = z_where_post.rsample()  # (b, A * h1 * w1, 4)
        else:
            z_where = z_where_loc

        pd_dict['z_where_post'] = z_where_post  # (b, A * h1 * w1, 4)

        # ---------------------------------------------------------------------------- #
        # Decode z_where to boxes
        # ---------------------------------------------------------------------------- #
        # (A * h1 * w1, 4)
        anchors = grid_anchors(self.cell_anchors, [h1, w1],
                               [int(h0 / h1), int(w0 / w1)])
        # (b, A * h1 * w1, 4)
        boxes = decode_boxes(anchors,
                             z_where,
                             clip_delta=self.clip_delta,
                             image_shape=(h0,
                                          w0) if self.clip_to_image else None)

        pd_dict['boxes'] = boxes
        pd_dict['grid_size'] = (h1, w1)

        # ---------------------------------------------------------------------------- #
        # Normalize boxes
        # Note that spatial transform assumes [-1, 1] for coordinates
        # ---------------------------------------------------------------------------- #
        x_min, y_min, x_max, y_max = torch.split(boxes, 1, dim=-1)
        # (b, A * h1 * w1, 4)
        normalized_boxes = torch.cat(
            [x_min / w0, y_min / h0, x_max / w0, y_max / h0], dim=-1)
        # convert xyxy to xywh
        normalized_boxes = boxes_xyxy2xywh(normalized_boxes)

        if fast:
            pd_dict['normalized_boxes'] = normalized_boxes
            return pd_dict

        # ---------------------------------------------------------------------------- #
        # Gaussian z_what
        # ---------------------------------------------------------------------------- #
        # Crop glimpses, (b * A * h1 * w1, c, h2, w2)
        glimpses = image_to_glimpse(image,
                                    normalized_boxes,
                                    glimpse_shape=self.glimpse_shape)
        # Gaussian z_what
        glimpses_feature = self.glimpse_encoder(glimpses)
        latent_what = self.latent_what(glimpses_feature.flatten(1))
        z_what_loc = latent_what[:, 0:self.z_what_size]
        z_what_scale = F.softplus(latent_what[:, self.z_what_size:])
        z_what_post = Normal(z_what_loc, z_what_scale)
        if self.training:
            z_what = z_what_post.rsample()
        else:
            z_what = z_what_loc

        pd_dict['z_what_post'] = z_what_post  # (b * A * h1 * w1, z_what_size)
        pd_dict['glimpses'] = glimpses.reshape(b, -1, c0,
                                               self.glimpse_shape[0],
                                               self.glimpse_shape[1])
        pd_dict['glimpses_feature'] = glimpses_feature.reshape(
            b, A * h1 * w1, -1)

        # ---------------------------------------------------------------------------- #
        # Decode z_what
        # ---------------------------------------------------------------------------- #
        # (b * A * h1 * w1, (c+1), h2, w2)
        glimpses_recon = self.glimpse_decoder(
            z_what.unsqueeze(-1).unsqueeze(-1))
        glimpses_recon = torch.sigmoid(glimpses_recon)

        glimpses_recon_reshape = glimpses_recon.reshape(
            b, -1, c0 + 1, self.glimpse_shape[0], self.glimpse_shape[1])
        pd_dict['glimpse_rgb'] = glimpses_recon_reshape[:, :, :-1]
        pd_dict['glimpse_alpha'] = glimpses_recon_reshape[:, :, -1:]

        # ---------------------------------------------------------------------------- #
        # Foreground
        # ---------------------------------------------------------------------------- #
        # (b * A * h1 * w1, c0 + 1, h0, w0)
        fg_rgba = glimpse_to_image(glimpses_recon,
                                   normalized_boxes,
                                   image_shape=(h0, w0))
        # (b, A * h1 * w1, c + 1, h, w)
        fg_rgba = fg_rgba.reshape(b, -1, c0 + 1, h0, w0)
        # (b, A * h1 * w1, 1, 1, 1)
        z_pres_reshape = z_pres.reshape(b, -1, 1, 1, 1)

        # Note that first c0 channels are rgb, and the last one is alpha.
        fg_rgb = fg_rgba[:, :, :-1]  # (b, A * h1 * w1, c0, h0, w0)
        fg_alpha = fg_rgba[:, :, -1:]  # (b, A * h1 * w1, 1, h0, w0)
        # Use foreground objects only
        fg_alpha_valid = fg_alpha * z_pres_reshape
        z_depth_reshape = z_depth.reshape(b, -1, 1, 1, 1)
        fg_weight = torch.softmax(fg_alpha_valid * depth_scale *
                                  torch.sigmoid(z_depth_reshape),
                                  dim=1)
        fg_mask_all = fg_alpha_valid * fg_weight
        fg_recon = (fg_rgb * fg_mask_all).sum(1)
        fg_mask = fg_mask_all.sum(1)

        pd_dict['fg_recon'] = fg_recon
        pd_dict['fg_mask'] = fg_mask

        # ---------------------------------------------------------------------------- #
        # Background
        # ---------------------------------------------------------------------------- #
        bg_feature = self.bg_encoder(image)
        bg_recon = self.bg_decoder(bg_feature)
        bg_recon = torch.sigmoid(bg_recon)
        bg_recon = bg_recon.reshape(b, c0, h0, w0)
        pd_dict['bg_recon'] = bg_recon

        return pd_dict
Example #14
    def forward(self, x, state=None):
        image = x
        canvas = torch.zeros_like(x.data)

        x, context = self.memory.init(image)

        c_data = context.data
        query = F.relu6(self.qdown(c_data))

        mu = []
        var = []
        stages = []
        masks = []

        for i in range(self.count):
            x, inverse = self.memory.glimpse(x, image)
            out = self.memory(query)
            o_mu = self.mu(out)
            o_var = self.var(out)
            mu.append(o_mu)
            var.append(o_var)
            out = self.sample(o_mu, o_var)
            out = F.relu(self.sup(out))
            out = self.decoder(out)

            inverse = inverse.view(out.size(0), 2, 3)

            grid = F.affine_grid(inverse, torch.Size([canvas.size(0), canvas.size(1) + 1, canvas.size(2), canvas.size(3)]))
            out = F.grid_sample(out.sigmoid(), grid)

            p = out[:, 0, :, :].unsqueeze(1)
            masks.append(p)

            out = out[:, 1:, :, :]

            dist = RelaxedBernoulli(torch.tensor([2.0]).to(p.device), probs=p)
            p = dist.rsample()

            canvas = canvas * (1 - p)
            out = out * p
            canvas += out

            if self.output_stages:
                square = self.square.clone().repeat(out.size(0), 1, 1, 1)
                square = F.grid_sample(square, grid)

                stage_image = out.data.clone()
                stage_image = stage_image + square
                stage_image = stage_image.clamp(0, 1)
                stages.append(stage_image.unsqueeze(1))

        if state is not None:
            state[torchbearer.Y_TRUE] = image
            state[MU] = torch.stack(mu, dim=1)
            state[LOGVAR] = torch.stack(var, dim=1)
            state[MASKED_TARGET] = state[torchbearer.Y_TRUE].detach() * p.detach()
            if self.output_stages:
                stages.append(image.clone().unsqueeze(1))
                state[STAGES] = torch.cat(stages, dim=1)

        return canvas
Example #15
    def get_probability_mask(self, batch_size):
        size = (batch_size, 1, 1)
        return RelaxedBernoulli(self.temperature,
                                self.probability).rsample(size)
Example #16
def gumbel_sigmoid(input: torch.Tensor, temp: float) -> torch.Tensor:
    """ gumbel sigmoid function
    """
    return RelaxedBernoulli(temp, probs=input.sigmoid()).rsample()
Example #17
def gumbel_sigmoid(input, temp):
    return RelaxedBernoulli(temp, probs=input).rsample()
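
Note the two variants expect different inputs: Example #16 takes raw scores and applies `sigmoid` internally, while Example #17 expects probabilities already in [0, 1]. An equivalent standalone call (hypothetical scores):

import torch
from torch.distributions import RelaxedBernoulli

scores = torch.randn(4, requires_grad=True)
sample = RelaxedBernoulli(torch.tensor(0.5), probs=scores.sigmoid()).rsample()
sample.sum().backward()  # rsample keeps the pathwise gradient to the scores
print(scores.grad is not None)  # True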
Example #18
    def propagate(self, x, state_post_prev, state_prior_prev, z_prev, bg):
        """
        Do propagation, conditioned on everything.
        Args:
            x: (B, 3, H, W), img
            (h, c), (h, c): each (B, N, D)
            z_prev:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)

        Returns:
            h_post, c_post: (B, N, D)
            h_prior, c_prior: (B, N, D)
            z:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
            kl:
                kl_pres: (B,)
                kl_what: (B,)
                kl_where: (B,)
                kl_depth: (B,)
            proposal_region: (B, N, 4)

        """
        z_pres_prev, z_depth_prev, z_where_prev, z_what_prev = z_prev
        B, N, _ = z_pres_prev.size()

        if N == 0:
            # No object is propagated
            return state_post_prev, state_prior_prev, z_prev, (0.0, 0.0, 0.0,
                                                               0.0), z_prev[2]

        h_post, c_post = state_post_prev
        h_prior, c_prior = state_prior_prev

        # Predict proposal locations, (B, N, 2)
        proposal_offset = self.pred_proposal(h_post)
        proposal = torch.zeros_like(z_where_prev)
        # Update size only
        proposal[..., 2:] = z_where_prev[..., 2:]
        proposal[
            ..., :2] = z_where_prev[..., :2] + ARCH.PROPOSAL_UPDATE_MIN + (
                ARCH.PROPOSAL_UPDATE_MAX -
                ARCH.PROPOSAL_UPDATE_MIN) * torch.sigmoid(proposal_offset)

        # Get proposal glimpses
        # (B*N, 3, H, W)
        x_repeat = torch.repeat_interleave(x[:, :3], N, dim=0)

        # (B*N, 3, H, W)
        proposal_glimpses = spatial_transform(x_repeat,
                                              proposal.view(B * N, 4),
                                              out_dims=(B * N, 3,
                                                        *ARCH.GLIMPSE_SHAPE))
        # (B, N, 3, H, W)
        proposal_glimpses = proposal_glimpses.view(B, N, 3,
                                                   *ARCH.GLIMPSE_SHAPE)
        # (B, N, D)
        proposal_enc = self.proposal_encoder(proposal_glimpses)
        # (B, N, D)
        # This will be used to condition everything
        enc = torch.cat([proposal_enc, h_post], dim=-1)

        # (B, N, D)
        (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale,
         z_where_offset_loc, z_where_offset_scale, z_what_offset_loc,
         z_what_offset_scale) = self.pres_depth_where_what_post_prop(enc)

        # Sampling
        z_pres_post = RelaxedBernoulli(self.tau, probs=z_pres_prob)
        z_pres = z_pres_post.rsample()
        z_pres = z_pres_prev * z_pres

        z_where_post = Normal(z_where_offset_loc, z_where_offset_scale)
        z_where_offset = z_where_post.rsample()
        z_where = torch.zeros_like(z_where_prev)
        # Scale
        z_where[..., :2] = z_where_prev[
            ..., :2] + ARCH.Z_SCALE_UPDATE_SCALE * torch.tanh(
                z_where_offset[..., :2])
        # Shift
        z_where[..., 2:] = z_where_prev[
            ..., 2:] + ARCH.Z_SHIFT_UPDATE_SCALE * torch.tanh(
                z_where_offset[..., 2:])

        z_depth_post = Normal(z_depth_offset_loc, z_depth_offset_scale)
        z_depth_offset = z_depth_post.rsample()
        z_depth = z_depth_prev + ARCH.Z_DEPTH_UPDATE_SCALE * z_depth_offset  # scale the offset, matching propagate_gen

        z_what_post = Normal(z_what_offset_loc, z_what_offset_scale)
        z_what_offset = z_what_post.rsample()
        z_what = z_what_prev + ARCH.Z_WHAT_UPDATE_SCALE * torch.tanh(
            z_what_offset)
        z = (z_pres, z_depth, z_where, z_what)

        # Update states
        state_post = self.temporal_encode(state_post_prev,
                                          z,
                                          bg,
                                          prior_or_post='post')
        state_prior = self.temporal_encode(state_prior_prev,
                                           z,
                                           bg,
                                           prior_or_post='prior')

        # Other priors
        (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale,
         z_where_offset_loc, z_where_offset_scale, z_what_offset_loc,
         z_what_offset_scale) = self.pres_depth_where_what_prior_prop(h_prior)

        z_depth_prior = Normal(z_depth_offset_loc, z_depth_offset_scale)
        z_where_prior = Normal(z_where_offset_loc, z_where_offset_scale)
        z_what_prior = Normal(z_what_offset_loc, z_what_offset_scale)

        # This is not a KL divergence; it is an auxiliary loss
        kl_pres = kl_divergence_bern_bern(
            z_pres_prob, torch.full_like(z_pres_prob, self.z_pres_prior_prob))
        kl_depth = kl_divergence(z_depth_post, z_depth_prior)
        kl_depth *= z_pres
        kl_where = kl_divergence(z_where_post, z_where_prior)
        kl_where *= z_pres
        kl_what = kl_divergence(z_what_post, z_what_prior)
        kl_what *= z_pres

        # Reduced to (B,)

        # Again, this is not really kl
        kl_pres = kl_pres.flatten(start_dim=1).sum(-1)
        kl_depth = kl_depth.flatten(start_dim=1).sum(-1)
        kl_where = kl_where.flatten(start_dim=1).sum(-1)
        kl_what = kl_what.flatten(start_dim=1).sum(-1)

        assert kl_pres.size(0) == B
        kl = (kl_pres, kl_depth, kl_where, kl_what)

        return state_post, state_prior, z, kl, proposal
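
`kl_divergence_bern_bern` is a helper from this snippet's codebase and is not shown; below is a plausible elementwise implementation of KL(Bern(p) || Bern(q)), offered only as an assumption consistent with the calls above:

import torch

def kl_divergence_bern_bern(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Assumed signature: elementwise KL between Bernoulli(p) and Bernoulli(q).
    return (p * ((p + eps).log() - (q + eps).log())
            + (1 - p) * ((1 - p + eps).log() - (1 - q + eps).log()))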
Example #19
    def propagate_gen(self, state_prev, z_prev, bg, sample=False):
        """
        
        Args:
            h_prev, c_prev: (B, N, D)
            z_prev:
                z_pres_prev: (B, N, 1)
                z_depth_prev: (B, N, 1)
                z_where_prev: (B, N, 4)
                z_what_prev: (B, N, D)
        Returns:
            h, c: (B, N, D)
            z:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
        """
        h_prev, c_prev = state_prev
        z_pres_prev, z_depth_prev, z_where_prev, z_what_prev = z_prev

        # (B, N, D)

        # TODO: z_pres_prior is not learned
        # All (B, N, D)
        (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale,
         z_where_offset_loc, z_where_offset_scale, z_what_offset_loc,
         z_what_offset_scale) = self.pres_depth_where_what_prior_prop(h_prev)

        z_pres_prior = RelaxedBernoulli(temperature=self.tau,
                                        probs=z_pres_prob)
        z_pres = z_pres_prior.sample()
        z_pres = (z_pres > 0.5).float()
        # NOTE: the thresholded sample is immediately overridden; z_pres is forced to all ones
        z_pres = torch.ones_like(z_pres)
        z_pres = z_pres_prev * z_pres

        z_where_prior = Normal(z_where_offset_loc, z_where_offset_scale)
        z_where_offset = z_where_prior.rsample(
        ) if sample else z_where_offset_loc
        z_where = torch.zeros_like(z_where_prev)
        # Scale
        z_where[..., :2] = z_where_prev[
            ..., :2] + ARCH.Z_SCALE_UPDATE_SCALE * torch.tanh(
                z_where_offset[..., :2])
        # Shift
        z_where[..., 2:] = z_where_prev[
            ..., 2:] + ARCH.Z_SHIFT_UPDATE_SCALE * torch.tanh(
                z_where_offset[..., 2:])

        z_depth_prior = Normal(z_depth_offset_loc, z_depth_offset_scale)
        z_depth_offset = z_depth_prior.rsample(
        ) if sample else z_depth_offset_loc
        z_depth = z_depth_prev + ARCH.Z_DEPTH_UPDATE_SCALE * z_depth_offset

        z_what_prior = Normal(z_what_offset_loc, z_what_offset_scale)
        z_what_offset = z_what_prior.rsample() if sample else z_what_offset_loc
        z_what = z_what_prev + ARCH.Z_WHAT_UPDATE_SCALE * torch.tanh(
            z_what_offset)

        z = (z_pres, z_depth, z_where, z_what)

        state = self.temporal_encode(state_prev, z, bg, prior_or_post='prior')

        return state, z
Example #20
    def discover(self, x, z_prop, bg, start_id=0):
        """
        Given current image and propagated objects, discover new objects
        Args:
            x: (B, D, H, W), current input image
            z_prop:
                z_pres_prop: (B, N, 1)
                z_depth_prop: (B, N, 1)
                z_where_prop: (B, N, 4)
                z_what_prop: (B, N, D)
            start_id: the id to start indexing

        Returns:
            (h_post, c_post): (B, N, D)
            (h_prior, c_prior): (B, N, D)
            z:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
            ids: (B, N)
            kl:
                kl_pres: (B,)
                kl_depth: (B,)
                kl_where: (B,)
                kl_what: (B,)
        """
        B, *_ = x.size()

        # (B, D, G, G)
        x_enc = self.img_encoder(x)
        # For each discovery cell, we combine propagated objects weighted by distances
        # (B, D, G, G)
        prop_map = self.compute_prop_map(z_prop)
        # (B, D, G, G)
        enc = torch.cat([x_enc, prop_map], dim=1)

        (z_pres_post_prob, z_depth_post_loc, z_depth_post_scale,
         z_where_post_loc, z_where_post_scale, z_what_post_loc,
         z_what_post_scale) = self.pres_depth_where_what_post_disc(enc)

        # Compute posteriors. All (B, G*G, D)
        z_pres_post = RelaxedBernoulli(temperature=self.tau,
                                       probs=z_pres_post_prob)
        z_pres = z_pres_post.rsample()

        z_depth_post = Normal(z_depth_post_loc, z_depth_post_scale)
        z_depth = z_depth_post.rsample()

        z_where_post = Normal(z_where_post_loc, z_where_post_scale)
        z_where = z_where_post.rsample()
        z_where = self.z_where_relative_to_absolute(z_where)

        z_what_post = Normal(z_what_post_loc, z_what_post_scale)
        z_what = z_what_post.rsample()

        # Combine
        z = (z_pres, z_depth, z_where, z_what)

        # Rejection
        if ARCH.REJECTION:
            z = self.rejection(z, z_prop, ARCH.REJECTION_THRESHOLD)

        # Compute object ids
        # (B, G*G) + (B, 1)
        ids = torch.arange(ARCH.G**2, device=x_enc.device).expand(
            B, ARCH.G**2) + start_id[:, None]

        # Update temporal states
        state_post_prev = self.get_state_init(B, 'post')
        state_post = self.temporal_encode(state_post_prev,
                                          z,
                                          bg,
                                          prior_or_post='post')

        state_prior_prev = self.get_state_init(B, 'prior')
        state_prior = self.temporal_encode(state_prior_prev,
                                           z,
                                           bg,
                                           prior_or_post='prior')

        # All (B, G*G, D)
        # Conditional kl divergences
        kl_pres = kl_divergence_bern_bern(
            z_pres_post_prob,
            torch.full_like(z_pres_post_prob, self.z_pres_prior_prob))

        z_depth_prior, z_where_prior, z_what_prior = self.get_discovery_priors(
            x.device)
        # where prior, (B, G*G, 4)
        kl_where = kl_divergence(z_where_post, z_where_prior)
        kl_where = kl_where * z_pres

        # what prior (B, G*G, D)
        kl_what = kl_divergence(z_what_post, z_what_prior)
        kl_what = kl_what * z_pres

        # what prior (B, G*G, D)
        kl_depth = kl_divergence(z_depth_post, z_depth_prior)
        kl_depth = kl_depth * z_pres

        # Sum over non-batch dimensions
        kl_pres = kl_pres.flatten(start_dim=1).sum(1)
        kl_where = kl_where.flatten(start_dim=1).sum(1)
        kl_what = kl_what.flatten(start_dim=1).sum(1)
        kl_depth = kl_depth.flatten(start_dim=1).sum(1)
        kl = (kl_pres, kl_depth, kl_where, kl_what)

        return state_post, state_prior, z, ids, kl