def _4_likelihood(self, X, nu, phi_var, phi):
    """
    @param X: (N, D)
    @param nu: (N, K)
    @param phi_var: (K, D)
    @param phi: (K, D)
    @return: ()

    Computes the likelihood term E_{q(Z), q(A)}[log p(X_n | Z_n, A, sigma_n^2 I)].
    Same as the finite approach.
    """
    N, _ = X.shape
    K, D = self.K, self.D  # for notational simplicity
    constant = -0.5 * D * (self.sigma_n.log() + LOG_2PI)

    # We use the Concrete / Gumbel-Softmax approximation.
    Z = RelaxedBernoulli(temperature=self.T, probs=nu).rsample()

    # These terms are essentially the same as in the finite case; nu is replaced by Z.
    first_term = X.pow(2).sum()
    second_term = (-2 * (Z.view(N, K, 1) * phi.view(1, K, D)) * X.view(N, 1, D)).sum()
    # This is Z^T E[A^T A] Z.
    third_term = torch.diag(
        Z @ (phi @ phi.transpose(0, 1) + phi_var.sum(1) * torch.eye(K)) @ Z.transpose(0, 1)
    ).sum()

    nonconstant = (-0.5 / (self.sigma_n ** 2)) * (first_term + second_term + third_term)
    return constant + nonconstant
def _sample_bipartite(self, u_c: Tensor, u_t: Tensor) -> Tensor:
    """Samples bipartite graph: p(A|U_R, U_M).

    Args:
        u_c (torch.Tensor): u input for context, size `(b, n, u_dim)`.
        u_t (torch.Tensor): u input for target, size `(b, m, u_dim)`.

    Returns:
        bipartite (torch.Tensor): Bipartite graph, size `(b, m, n)`.
    """
    # Indices for pairs (u_t_i, u_c_j)
    b, n, _ = u_c.size()
    m = u_t.size(1)
    indices = torch.tensor(list(product(range(m), range(n)))).t()

    # Latent pairs (b, num_pairs, u_dim)
    pair_0 = u_t[:, indices[0]]
    pair_1 = u_c[:, indices[1]]

    # Compute logits for each pair
    logp = -0.5 * ((pair_0 - pair_1) ** 2).sum(dim=-1) / self.scale.exp()
    logits = logitexp(logp)

    # Sample graph from Bernoulli dist (b, num_pairs)
    dist = RelaxedBernoulli(logits=logits, temperature=self.temperature)
    p_edges = dist.rsample()

    # Embed values
    bipartite = u_c.new_zeros((b, m, n))
    bipartite[:, indices[0], indices[1]] = p_edges

    return bipartite
def sample():
    probabilities_dist = RelaxedBernoulli(temperature, probabilities)
    sample_probabilities = probabilities_dist.rsample()
    sample_probabilities = sample_probabilities.clamp(0.0, 1.0)
    sample_probabilities_index = sample_probabilities >= 0.5
    # Straight-through estimator: the forward value is the hard 0/1 sample,
    # while the backward pass uses gradients of the relaxed sample.
    sample_probabilities = (sample_probabilities_index.float()
                            - sample_probabilities.detach() + sample_probabilities)
    return sample_probabilities, sample_probabilities_index
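A minimal, self-contained sketch of how the straight-through pattern above behaves; the tensor shapes, temperature, and the use of a sum as a stand-in loss are illustrative assumptions, not part of the snippet. The returned value is exactly binary in the forward pass, yet gradients still reach the logits through the relaxed sample.

# Hypothetical usage of the straight-through RelaxedBernoulli trick shown above.
import torch
from torch.distributions import RelaxedBernoulli

temperature = torch.tensor(0.5)
logits = torch.randn(4, requires_grad=True)
probabilities = torch.sigmoid(logits)

relaxed = RelaxedBernoulli(temperature, probabilities).rsample().clamp(0.0, 1.0)
hard = (relaxed >= 0.5).float()
# Forward value is hard (0/1); backward gradient follows the relaxed sample.
sample = hard - relaxed.detach() + relaxed

sample.sum().backward()
print(sample)       # binary values
print(logits.grad)  # non-zero: the relaxed path carries the gradient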
def forward(self, inputs, rnn_hxs, masks):
    x = inputs
    m_soft = RelaxedBernoulli(1.0, logits=self.input_attention).sample()
    m_hard = 0.5 * (torch.sign(m_soft - 0.5) + 1)
    mask = m_hard - m_soft.detach() + m_soft
    x = mask * x

    if self.is_recurrent:
        x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

    hidden_critic = self.critic(x)
    hidden_actor = self.actor(x)

    return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
def __init__(self):
    super().__init__()
    self.encoder = ModuleCompose(
        nn.Conv2d(1, 32, 3, stride=2, padding=1),
        F.relu,
        nn.Conv2d(32, 64, 3, stride=2, padding=1),
    )
    self.decoder = ModuleCompose(
        ConvPixelShuffle(64, 32, upscale_factor=2),
        F.relu,
        ConvPixelShuffle(32, 1, upscale_factor=2),
        lambda x: x[:, 0],
    )
    # Alternatives:
    # - RelaxedBernoulli - maybe doesn't work?
    # - RelaxedOneHotCategorical
    # - RelaxedOneHotCategorical * Codebook
    self.image = ModuleCompose(
        self.encoder,
        lambda logits: RelaxedBernoulli(
            temperature=0.5,
            logits=logits,
        ).rsample(),
        self.decoder,
    )
def get_mask(self, batch_size=None) -> torch.Tensor:
    size = (batch_size, 1, 1)
    if self.training:
        return RelaxedBernoulli(self.temperature, self.probability).rsample(size)
    else:
        return Bernoulli(self.probability).sample(size)
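A minimal sketch of a module this get_mask pattern could sit in; the StochasticGate class, its parameter values, and the calling code are assumptions for illustration only. It shows the train/eval split: a differentiable relaxed mask while training, a hard Bernoulli mask at evaluation.

import torch
import torch.nn as nn
from torch.distributions import Bernoulli, RelaxedBernoulli

class StochasticGate(nn.Module):
    # Illustrative container, not the original class.
    def __init__(self, temperature=0.3, probability=0.8):
        super().__init__()
        self.temperature = torch.tensor(temperature)
        self.probability = torch.tensor(probability)

    def get_mask(self, batch_size=None) -> torch.Tensor:
        size = (batch_size, 1, 1)
        if self.training:
            # Relaxed, differentiable mask during training ...
            return RelaxedBernoulli(self.temperature, self.probability).rsample(size)
        # ... exact hard Bernoulli mask at evaluation time.
        return Bernoulli(self.probability).sample(size)

gate = StochasticGate()
gate.train()
print(gate.get_mask(batch_size=2))  # values in (0, 1)
gate.eval()
print(gate.get_mask(batch_size=2))  # values in {0, 1}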
def gumbel_softmax_bit_vector_sample(logits: torch.Tensor,
                                     temperature: float = 1.0,
                                     straight_through: bool = False):
    """Samples from a Gumbel-Softmax/Concrete relaxation of independent Bernoulli distributions.

    More details in:
    - Gumbel-Softmax: https://arxiv.org/abs/1611.01144
    - Concrete distribution: https://arxiv.org/abs/1611.00712

    Arguments:
        logits {torch.Tensor} -- tensor of logits, the output of an inference network.
            Size: [batch_size, n_bits]

    Keyword Arguments:
        temperature {float} -- temperature of the softmax relaxation. The lower the
            temperature (--> 0), the closer the sample is to a discrete sample.
            (default: {1.0})
        straight_through {bool} -- whether to use the straight-through estimator.
            (default: {False})

    Returns:
        torch.Tensor -- the relaxed sample. Size: [batch_size, n_bits]
    """
    sample = RelaxedBernoulli(logits=logits, temperature=temperature).rsample()

    if straight_through:
        hard_sample = (logits > 0).to(torch.float)
        sample = sample + (hard_sample - sample).detach()

    return sample
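An illustrative call of the function above; the batch size, bit count, and temperature are assumptions. Note that this implementation thresholds the logits rather than the relaxed sample when producing the hard bits, so with straight_through=True the forward value is deterministic given the logits while gradients flow through the relaxed sample.

# Hypothetical usage; shapes and temperature are illustrative.
import torch

logits = torch.randn(8, 16, requires_grad=True)

soft_bits = gumbel_softmax_bit_vector_sample(logits, temperature=0.5)
hard_bits = gumbel_softmax_bit_vector_sample(logits, temperature=0.5, straight_through=True)

assert soft_bits.shape == hard_bits.shape == (8, 16)
assert set(hard_bits.unique().tolist()) <= {0.0, 1.0}  # exact bits in the forward pass
hard_bits.sum().backward()  # gradients reach `logits` via the relaxed sample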
def forward(self, inputs, rnn_hxs, masks):
    x = inputs
    probs = F.softmax(self.input_attention, dim=0)
    probs = probs / torch.max(probs)
    m_soft = RelaxedBernoulli(1.0, probs=probs).sample()
    attn_log_probs = RelaxedBernoulli(1.0, probs=probs).log_prob(m_soft)
    mask = 0.5 * (torch.sign(m_soft - 0.5) + 1)
    x = mask * x

    if self.is_recurrent:
        x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

    hidden_critic = self.critic(x)
    hidden_actor = self.actor(x)

    return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, attn_log_probs
def _sample_dag(self, u_c: Tensor) -> Tensor:
    """Samples DAG from context data: p(G|U_R).

    Args:
        u_c (torch.Tensor): u input for context, size `(b, n, u_dim)`.

    Returns:
        graph (torch.Tensor): Sampled DAG, size `(b, n, n)`.
    """
    # Data size
    b, n, _ = u_c.size()

    # Ordering by log CDF
    log_cdf = (0.5 * (u_c / 2 ** 0.5).erf() + 0.5).log().sum(dim=-1)
    u_c_sorted, sort_idx = log_cdf.sort()

    # Indices of upper triangular adjacency matrix for DAG
    indices = torch.triu_indices(n, n, offset=1)

    # Latent pairs (b, num_pairs)
    pair_0 = u_c_sorted[:, indices[0]]
    pair_1 = u_c_sorted[:, indices[1]]

    # Compute logits for each pair
    logp = -0.5 * (pair_0 - pair_1) ** 2 / self.scale.exp()
    logits = logitexp(logp)

    # Sample graph from Bernoulli dist (b, num_pairs)
    dist = RelaxedBernoulli(logits=logits, temperature=self.temperature)
    sorted_graph = dist.rsample()

    # Embed upper triangular into adjacency matrix
    graph = u_c.new_zeros((b, n, n))
    graph[:, indices[0], indices[1]] = sorted_graph

    # Unsort index of DAG to data order
    col_idx = torch.argsort(sort_idx)
    col_idx = col_idx.unsqueeze(1).repeat(1, n, 1)

    # Swap to unsort: 1. columns, 2. indices as columns
    graph = torch.gather(graph, -1, col_idx)
    graph = torch.gather(graph.permute(0, 2, 1), -1, col_idx)
    graph = graph.permute(0, 2, 1)

    return graph
def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:
    p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
        p_depth_std, p_what_mean, p_what_std = ss

    if phase_use_mode:
        z_pres = (p_pres_logits > 0).float()
    else:
        z_pres = RelaxedBernoulli(logits=p_pres_logits,
                                  temperature=self.args.train.tau_pres).rsample()

    # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
    if phase_use_mode:
        z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
    else:
        z_where_scale, z_where_shift = \
            Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

    # z_where_origin: (bs, dim, num_cell, num_cell)
    z_where_origin = \
        torch.cat([z_where_scale.detach(), z_where_shift.detach()], dim=1)

    z_where_shift = \
        (2. / self.args.arch.num_cell) * \
        (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

    scale, ratio = z_where_scale.chunk(2, 1)
    scale = scale.sigmoid()
    ratio = torch.exp(ratio)
    ratio_sqrt = ratio.sqrt()
    z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)

    # z_where: (bs, dim, num_cell, num_cell)
    z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

    if phase_use_mode:
        z_depth = p_depth_mean
        z_what = p_what_mean
    else:
        z_depth = Normal(p_depth_mean, p_depth_std).rsample()
        z_what = Normal(p_what_mean, p_what_std).rsample()

    z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim) \
        .view(-1, self.args.z.z_what_dim, 1, 1)

    if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
        o = self.z_what_decoder_net(z_what_reshape)
        o = o.sigmoid()
        a = o.new_ones(o.size())
    elif self.args.arch.phase_overlap:
        o, a = self.z_what_decoder_net(z_what_reshape).split(
            [self.args.data.inp_channel, 1], dim=1)
        o, a = o.sigmoid(), a.sigmoid()
    else:
        raise NotImplementedError

    lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
    pa = [o, a]

    return pa, lv
def gumbel_softmax_bit_vector_sample(logits: torch.Tensor,
                                     temperature: float = 1.0,
                                     straight_through: bool = False):
    sample = RelaxedBernoulli(logits=logits, temperature=temperature).rsample()

    if straight_through:
        hard_sample = (logits > 0).to(torch.float)
        sample = sample + (hard_sample - sample).detach()

    return sample
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    if isinstance(distr, (Categorical, OneHotCategorical)):
        if straight_through:
            gumbel_distr = RelaxedOneHotCategoricalStraightThrough(
                temperature, probs=distr.probs)
        else:
            gumbel_distr = RelaxedOneHotCategorical(temperature, probs=distr.probs)
    elif isinstance(distr, Bernoulli):
        if straight_through:
            gumbel_distr = RelaxedBernoulliStraightThrough(temperature,
                                                           probs=distr.probs)
        else:
            gumbel_distr = RelaxedBernoulli(temperature, probs=distr.probs)
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")
    return gumbel_distr.rsample((n,))
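A small usage sketch for the dispatcher above; the distribution parameters, sample count, and temperature are assumptions. Passing a Bernoulli yields relaxed bit samples, passing a OneHotCategorical yields relaxed one-hot samples.

# Hypothetical usage; probabilities, n, and temperature are illustrative.
import torch
from torch.distributions import Bernoulli, OneHotCategorical

temperature = torch.tensor(2.0 / 3.0)

bern = Bernoulli(probs=torch.tensor([0.2, 0.9]))
print(rsample_gumbel_softmax(bern, n=5, temperature=temperature).shape)
# torch.Size([5, 2]); entries lie in (0, 1)

cat = OneHotCategorical(probs=torch.tensor([0.1, 0.3, 0.6]))
print(rsample_gumbel_softmax(cat, n=5, temperature=temperature).shape)
# torch.Size([5, 3]); each row sums to 1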
def forward(self, data_batch: dict, temperature=1.0, depth_scale=10.0, fast=False, **kwargs):
    pd_dict = dict()
    image = data_batch['image']
    b, c0, h0, w0 = image.size()
    A = self.num_anchors_per_cell

    # ---------------------------------------------------------------------------- #
    # CNN encodes feature maps
    # ---------------------------------------------------------------------------- #
    fg_feature = self.fg_encoder(image)
    fg_feature = self.rpn(fg_feature)
    _, c1, h1, w1 = fg_feature.size()

    # ---------------------------------------------------------------------------- #
    # Relaxed Bernoulli z_pres
    # ---------------------------------------------------------------------------- #
    latent_pres = self.latent_pres(fg_feature)  # (b, A, h1, w1)
    z_pres_p = torch.sigmoid(latent_pres)
    z_pres_p = z_pres_p.reshape(b, -1)  # (b, A * h1 * w1)
    # In order to avoid gradient explosion at 0 and 1, clip
    z_pres_p = z_pres_p.clamp(min=self._eps, max=1.0 - self._eps)
    z_pres_post = RelaxedBernoulli(z_pres_p.new_tensor(temperature), probs=z_pres_p)
    if self.training:
        z_pres = z_pres_post.rsample()
    else:
        z_pres = z_pres_p
    pd_dict['z_pres'] = z_pres  # (b, A * h1 * w1)
    pd_dict['z_pres_p'] = z_pres_p
    pd_dict['z_pres_post'] = z_pres_post

    # ---------------------------------------------------------------------------- #
    # Gaussian z_depth
    # ---------------------------------------------------------------------------- #
    latent_depth = self.latent_depth(fg_feature)  # (b, A * 2, h1, w1)
    z_depth_loc = latent_depth.narrow(1, 0, A)
    z_depth_scale = F.softplus(latent_depth.narrow(1, A, A))
    z_depth_post = Normal(z_depth_loc.reshape(b, -1), z_depth_scale.reshape(b, -1))
    if self.training:
        z_depth = z_depth_post.rsample()  # (b, A * h1 * w1)
    else:
        z_depth = z_depth_loc
    pd_dict['z_depth_post'] = z_depth_post  # (b, A * h1 * w1)
    pd_dict['z_depth'] = z_depth  # (b, A * h1 * w1)

    # ---------------------------------------------------------------------------- #
    # Gaussian z_where
    # (offset_x, offset_y, scale_x, scale_y)
    # ---------------------------------------------------------------------------- #
    latent_where = self.latent_where(fg_feature)  # (b, A * 8, h1, w1)
    latent_where = latent_where.reshape(b, A, 8, h1, w1)  # (b, A, 8, h1, w1)
    latent_where = latent_where.permute(0, 1, 3, 4, 2).contiguous()  # (b, A, h1, w1, 8)
    latent_where = latent_where.reshape(b, A * h1 * w1, 8)
    z_where_loc = latent_where.narrow(-1, 0, 4)
    z_where_scale = F.softplus(latent_where.narrow(-1, 4, 4))
    z_where_post = Normal(z_where_loc, z_where_scale)
    if self.training:
        z_where = z_where_post.rsample()  # (b, A * h1 * w1, 4)
    else:
        z_where = z_where_loc
    pd_dict['z_where_post'] = z_where_post  # (b, A * h1 * w1, 4)

    # ---------------------------------------------------------------------------- #
    # Decode z_where to boxes
    # ---------------------------------------------------------------------------- #
    # (A * h1 * w1, 4)
    anchors = grid_anchors(self.cell_anchors, [h1, w1], [int(h0 / h1), int(w0 / w1)])
    # (b, A * h1 * w1, 4)
    boxes = decode_boxes(anchors, z_where,
                         clip_delta=self.clip_delta,
                         image_shape=(h0, w0) if self.clip_to_image else None)
    pd_dict['boxes'] = boxes
    pd_dict['grid_size'] = (h1, w1)

    # ---------------------------------------------------------------------------- #
    # Normalize boxes
    # Note that spatial transform assumes [-1, 1] for coordinates
    # ---------------------------------------------------------------------------- #
    x_min, y_min, x_max, y_max = torch.split(boxes, 1, dim=-1)
    # (b, A * h1 * w1, 4)
    normalized_boxes = torch.cat([x_min / w0, y_min / h0, x_max / w0, y_max / h0], dim=-1)
    # convert xyxy to xywh
    normalized_boxes = boxes_xyxy2xywh(normalized_boxes)

    if fast:
        pd_dict['normalized_boxes'] = normalized_boxes
        return pd_dict

    # ---------------------------------------------------------------------------- #
    # Gaussian z_what
    # ---------------------------------------------------------------------------- #
    # Crop glimpses, (b * A * h1 * w1, c, h2, w2)
    glimpses = image_to_glimpse(image, normalized_boxes, glimpse_shape=self.glimpse_shape)

    # Gaussian z_what
    glimpses_feature = self.glimpse_encoder(glimpses)
    latent_what = self.latent_what(glimpses_feature.flatten(1))
    z_what_loc = latent_what[:, 0:self.z_what_size]
    z_what_scale = F.softplus(latent_what[:, self.z_what_size:])
    z_what_post = Normal(z_what_loc, z_what_scale)
    if self.training:
        z_what = z_what_post.rsample()
    else:
        z_what = z_what_loc
    pd_dict['z_what_post'] = z_what_post  # (b * A * h1 * w1, z_what_size)
    pd_dict['glimpses'] = glimpses.reshape(b, -1, c0,
                                           self.glimpse_shape[0], self.glimpse_shape[1])
    pd_dict['glimpses_feature'] = glimpses_feature.reshape(b, A * h1 * w1, -1)

    # ---------------------------------------------------------------------------- #
    # Decode z_what
    # ---------------------------------------------------------------------------- #
    # (b * A * h1 * w1, (c+1), h2, w2)
    glimpses_recon = self.glimpse_decoder(z_what.unsqueeze(-1).unsqueeze(-1))
    glimpses_recon = torch.sigmoid(glimpses_recon)
    glimpses_recon_reshape = glimpses_recon.reshape(
        b, -1, c0 + 1, self.glimpse_shape[0], self.glimpse_shape[1])
    pd_dict['glimpse_rgb'] = glimpses_recon_reshape[:, :, :-1]
    pd_dict['glimpse_alpha'] = glimpses_recon_reshape[:, :, -1:]

    # ---------------------------------------------------------------------------- #
    # Foreground
    # ---------------------------------------------------------------------------- #
    # (b * A * h1 * w1, c0 + 1, h0, w0)
    fg_rgba = glimpse_to_image(glimpses_recon, normalized_boxes, image_shape=(h0, w0))
    # (b, A * h1 * w1, c + 1, h, w)
    fg_rgba = fg_rgba.reshape(b, -1, c0 + 1, h0, w0)
    # (b, A * h1 * w1, 1, 1, 1)
    z_pres_reshape = z_pres.reshape(b, -1, 1, 1, 1)
    # Note that the first c0 channels are rgb, and the last one is alpha.
    fg_rgb = fg_rgba[:, :, :-1]  # (b, A * h1 * w1, c0, h0, w0)
    fg_alpha = fg_rgba[:, :, -1:]  # (b, A * h1 * w1, 1, h0, w0)
    # Use foreground objects only
    fg_alpha_valid = fg_alpha * z_pres_reshape
    z_depth_reshape = z_depth.reshape(b, -1, 1, 1, 1)
    fg_weight = torch.softmax(fg_alpha_valid * depth_scale * torch.sigmoid(z_depth_reshape),
                              dim=1)
    fg_mask_all = fg_alpha_valid * fg_weight
    fg_recon = (fg_rgb * fg_mask_all).sum(1)
    fg_mask = fg_mask_all.sum(1)
    pd_dict['fg_recon'] = fg_recon
    pd_dict['fg_mask'] = fg_mask

    # ---------------------------------------------------------------------------- #
    # Background
    # ---------------------------------------------------------------------------- #
    bg_feature = self.bg_encoder(image)
    bg_recon = self.bg_decoder(bg_feature)
    bg_recon = torch.sigmoid(bg_recon)
    bg_recon = bg_recon.reshape(b, c0, h0, w0)
    pd_dict['bg_recon'] = bg_recon

    return pd_dict
def forward(self, x, state=None):
    image = x
    canvas = torch.zeros_like(x.data)
    x, context = self.memory.init(image)
    c_data = context.data
    query = F.relu6(self.qdown(c_data))

    mu = []
    var = []
    stages = []
    masks = []

    for i in range(self.count):
        x, inverse = self.memory.glimpse(x, image)
        out = self.memory(query)
        o_mu = self.mu(out)
        o_var = self.var(out)
        mu.append(o_mu)
        var.append(o_var)
        out = self.sample(o_mu, o_var)
        out = F.relu(self.sup(out))
        out = self.decoder(out)

        inverse = inverse.view(out.size(0), 2, 3)
        grid = F.affine_grid(inverse,
                             torch.Size([canvas.size(0), canvas.size(1) + 1,
                                         canvas.size(2), canvas.size(3)]))
        out = F.grid_sample(out.sigmoid(), grid)
        p = out[:, 0, :, :].unsqueeze(1)
        masks.append(p)
        out = out[:, 1:, :, :]

        dist = RelaxedBernoulli(torch.tensor([2.0]).to(p.device), probs=p)
        p = dist.rsample()

        canvas = canvas * (1 - p)
        out = out * p
        canvas += out

        if self.output_stages:
            square = self.square.clone().repeat(out.size(0), 1, 1, 1)
            square = F.grid_sample(square, grid)
            stage_image = out.data.clone()
            stage_image = stage_image + square
            stage_image = stage_image.clamp(0, 1)
            stages.append(stage_image.unsqueeze(1))

    if state is not None:
        state[torchbearer.Y_TRUE] = image
        state[MU] = torch.stack(mu, dim=1)
        state[LOGVAR] = torch.stack(var, dim=1)
        state[MASKED_TARGET] = state[torchbearer.Y_TRUE].detach() * p.detach()

        if self.output_stages:
            stages.append(image.clone().unsqueeze(1))
            state[STAGES] = torch.cat(stages, dim=1)

    return canvas
def get_probability_mask(self, batch_size):
    size = (batch_size, 1, 1)
    return RelaxedBernoulli(self.temperature, self.probability).rsample(size)
def gumbel_sigmoid(input: torch.Tensor, temp: float) -> torch.Tensor:
    """Gumbel sigmoid function."""
    return RelaxedBernoulli(temp, probs=input.sigmoid()).rsample()
def gumbel_sigmoid(input, temp):
    return RelaxedBernoulli(temp, probs=input).rsample()
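The two gumbel_sigmoid variants above differ only in where the sigmoid is applied: the first expects raw scores and squashes them internally, the second expects probabilities already in (0, 1). A brief, assumed usage of the probability-based variant (input shape and temperature are illustrative):

# Hypothetical usage; shapes and temperature are illustrative.
import torch

probs = torch.sigmoid(torch.randn(3, 4, requires_grad=True))
relaxed = gumbel_sigmoid(probs, temp=torch.tensor(0.5))
relaxed.sum().backward()  # rsample() keeps the sample differentiable w.r.t. probs
print(relaxed.shape)  # torch.Size([3, 4]); values in (0, 1)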
def propagate(self, x, state_post_prev, state_prior_prev, z_prev, bg):
    """
    Do propagation, conditioned on everything.

    Args:
        x: (B, 3, H, W), img
        (h, c), (h, c): each (B, N, D)
        z_prev:
            z_pres: (B, N, 1)
            z_depth: (B, N, 1)
            z_where: (B, N, 4)
            z_what: (B, N, D)

    Returns:
        h_post, c_post: (B, N, D)
        h_prior, c_prior: (B, N, D)
        z:
            z_pres: (B, N, 1)
            z_depth: (B, N, 1)
            z_where: (B, N, 4)
            z_what: (B, N, D)
        kl:
            kl_pres: (B,)
            kl_what: (B,)
            kl_where: (B,)
            kl_depth: (B,)
        proposal_region: (B, N, 4)
    """
    z_pres_prev, z_depth_prev, z_where_prev, z_what_prev = z_prev
    B, N, _ = z_pres_prev.size()

    if N == 0:
        # No object is propagated
        return state_post_prev, state_prior_prev, z_prev, (0.0, 0.0, 0.0, 0.0), z_prev[2]

    h_post, c_post = state_post_prev
    h_prior, c_prior = state_prior_prev

    # Predict proposal locations, (B, N, 2)
    proposal_offset = self.pred_proposal(h_post)
    proposal = torch.zeros_like(z_where_prev)
    # Update size only
    proposal[..., 2:] = z_where_prev[..., 2:]
    proposal[..., :2] = z_where_prev[..., :2] + ARCH.PROPOSAL_UPDATE_MIN + (
        ARCH.PROPOSAL_UPDATE_MAX - ARCH.PROPOSAL_UPDATE_MIN) * torch.sigmoid(proposal_offset)

    # Get proposal glimpses
    # (B*N, 3, H, W)
    x_repeat = torch.repeat_interleave(x[:, :3], N, dim=0)

    # (B*N, 3, H, W)
    proposal_glimpses = spatial_transform(x_repeat, proposal.view(B * N, 4),
                                          out_dims=(B * N, 3, *ARCH.GLIMPSE_SHAPE))
    # (B, N, 3, H, W)
    proposal_glimpses = proposal_glimpses.view(B, N, 3, *ARCH.GLIMPSE_SHAPE)
    # (B, N, D)
    proposal_enc = self.proposal_encoder(proposal_glimpses)
    # (B, N, D)
    # This will be used to condition everything
    enc = torch.cat([proposal_enc, h_post], dim=-1)

    # All (B, N, D)
    (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale, z_where_offset_loc,
     z_where_offset_scale, z_what_offset_loc,
     z_what_offset_scale) = self.pres_depth_where_what_post_prop(enc)

    # Sampling
    z_pres_post = RelaxedBernoulli(self.tau, probs=z_pres_prob)
    z_pres = z_pres_post.rsample()
    z_pres = z_pres_prev * z_pres

    z_where_post = Normal(z_where_offset_loc, z_where_offset_scale)
    z_where_offset = z_where_post.rsample()
    z_where = torch.zeros_like(z_where_prev)
    # Scale
    z_where[..., :2] = z_where_prev[..., :2] + ARCH.Z_SCALE_UPDATE_SCALE * torch.tanh(
        z_where_offset[..., :2])
    # Shift
    z_where[..., 2:] = z_where_prev[..., 2:] + ARCH.Z_SHIFT_UPDATE_SCALE * torch.tanh(
        z_where_offset[..., 2:])

    z_depth_post = Normal(z_depth_offset_loc, z_depth_offset_scale)
    z_depth_offset = z_depth_post.rsample()
    # Scale the depth offset before adding it (mirrors propagate_gen below).
    z_depth = z_depth_prev + ARCH.Z_DEPTH_UPDATE_SCALE * z_depth_offset

    z_what_post = Normal(z_what_offset_loc, z_what_offset_scale)
    z_what_offset = z_what_post.rsample()
    z_what = z_what_prev + ARCH.Z_WHAT_UPDATE_SCALE * torch.tanh(z_what_offset)

    z = (z_pres, z_depth, z_where, z_what)

    # Update states
    state_post = self.temporal_encode(state_post_prev, z, bg, prior_or_post='post')
    state_prior = self.temporal_encode(state_prior_prev, z, bg, prior_or_post='prior')

    # Other priors
    (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale, z_where_offset_loc,
     z_where_offset_scale, z_what_offset_loc,
     z_what_offset_scale) = self.pres_depth_where_what_prior_prop(h_prior)

    z_depth_prior = Normal(z_depth_offset_loc, z_depth_offset_scale)
    z_where_prior = Normal(z_where_offset_loc, z_where_offset_scale)
    z_what_prior = Normal(z_what_offset_loc, z_what_offset_scale)

    # This is not a KL divergence. This is an auxiliary loss.
    kl_pres = kl_divergence_bern_bern(
        z_pres_prob, torch.full_like(z_pres_prob, self.z_pres_prior_prob))
    kl_depth = kl_divergence(z_depth_post, z_depth_prior)
    kl_depth *= z_pres
    kl_where = kl_divergence(z_where_post, z_where_prior)
    kl_where *= z_pres
    kl_what = kl_divergence(z_what_post, z_what_prior)
    kl_what *= z_pres

    # Reduced to (B,)
    # Again, this is not really KL
    kl_pres = kl_pres.flatten(start_dim=1).sum(-1)
    kl_depth = kl_depth.flatten(start_dim=1).sum(-1)
    kl_where = kl_where.flatten(start_dim=1).sum(-1)
    kl_what = kl_what.flatten(start_dim=1).sum(-1)
    assert kl_pres.size(0) == B
    kl = (kl_pres, kl_depth, kl_where, kl_what)

    return state_post, state_prior, z, kl, proposal
def propagate_gen(self, state_prev, z_prev, bg, sample=False):
    """
    Args:
        h_prev, c_prev: (B, N, D)
        z_prev:
            z_pres_prev: (B, N, 1)
            z_depth_prev: (B, N, 1)
            z_where_prev: (B, N, 4)
            z_what_prev: (B, N, D)

    Returns:
        h, c: (B, N, D)
        z:
            z_pres: (B, N, 1)
            z_depth: (B, N, 1)
            z_where: (B, N, 4)
            z_what: (B, N, D)
    """
    h_prev, c_prev = state_prev
    z_pres_prev, z_depth_prev, z_where_prev, z_what_prev = z_prev

    # TODO: z_pres_prior is not learned
    # All (B, N, D)
    (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale, z_where_offset_loc,
     z_where_offset_scale, z_what_offset_loc,
     z_what_offset_scale) = self.pres_depth_where_what_prior_prop(h_prev)

    z_pres_prior = RelaxedBernoulli(temperature=self.tau, probs=z_pres_prob)
    z_pres = z_pres_prior.sample()
    z_pres = (z_pres > 0.5).float()
    # NOTE: z_pres is then overwritten with ones, so every propagated object is
    # kept during generation.
    z_pres = torch.ones_like(z_pres)
    z_pres = z_pres_prev * z_pres

    z_where_prior = Normal(z_where_offset_loc, z_where_offset_scale)
    z_where_offset = z_where_prior.rsample() if sample else z_where_offset_loc
    z_where = torch.zeros_like(z_where_prev)
    # Scale
    z_where[..., :2] = z_where_prev[..., :2] + ARCH.Z_SCALE_UPDATE_SCALE * torch.tanh(
        z_where_offset[..., :2])
    # Shift
    z_where[..., 2:] = z_where_prev[..., 2:] + ARCH.Z_SHIFT_UPDATE_SCALE * torch.tanh(
        z_where_offset[..., 2:])

    z_depth_prior = Normal(z_depth_offset_loc, z_depth_offset_scale)
    z_depth_offset = z_depth_prior.rsample() if sample else z_depth_offset_loc
    z_depth = z_depth_prev + ARCH.Z_DEPTH_UPDATE_SCALE * z_depth_offset

    z_what_prior = Normal(z_what_offset_loc, z_what_offset_scale)
    z_what_offset = z_what_prior.rsample() if sample else z_what_offset_loc
    z_what = z_what_prev + ARCH.Z_WHAT_UPDATE_SCALE * torch.tanh(z_what_offset)

    z = (z_pres, z_depth, z_where, z_what)
    state = self.temporal_encode(state_prev, z, bg, prior_or_post='prior')

    return state, z
def discover(self, x, z_prop, bg, start_id=0):
    """
    Given the current image and propagated objects, discover new objects.

    Args:
        x: (B, D, H, W), current input image
        z_prop:
            z_pres_prop: (B, N, 1)
            z_depth_prop: (B, N, 1)
            z_where_prop: (B, N, 4)
            z_what_prop: (B, N, D)
        start_id: the id to start indexing

    Returns:
        (h_post, c_post): (B, N, D)
        (h_prior, c_prior): (B, N, D)
        z:
            z_pres: (B, N, 1)
            z_depth: (B, N, 1)
            z_where: (B, N, 4)
            z_what: (B, N, D)
        ids: (B, N)
        kl:
            kl_pres: (B,)
            kl_depth: (B,)
            kl_where: (B,)
            kl_what: (B,)
    """
    B, *_ = x.size()

    # (B, D, G, G)
    x_enc = self.img_encoder(x)
    # For each discovery cell, we combine propagated objects weighted by distances
    # (B, D, G, G)
    prop_map = self.compute_prop_map(z_prop)
    # (B, D, G, G)
    enc = torch.cat([x_enc, prop_map], dim=1)

    (z_pres_post_prob, z_depth_post_loc, z_depth_post_scale, z_where_post_loc,
     z_where_post_scale, z_what_post_loc,
     z_what_post_scale) = self.pres_depth_where_what_post_disc(enc)

    # Compute posteriors. All (B, G*G, D)
    z_pres_post = RelaxedBernoulli(temperature=self.tau, probs=z_pres_post_prob)
    z_pres = z_pres_post.rsample()

    z_depth_post = Normal(z_depth_post_loc, z_depth_post_scale)
    z_depth = z_depth_post.rsample()

    z_where_post = Normal(z_where_post_loc, z_where_post_scale)
    z_where = z_where_post.rsample()
    z_where = self.z_where_relative_to_absolute(z_where)

    z_what_post = Normal(z_what_post_loc, z_what_post_scale)
    z_what = z_what_post.rsample()

    # Combine
    z = (z_pres, z_depth, z_where, z_what)

    # Rejection
    if ARCH.REJECTION:
        z = self.rejection(z, z_prop, ARCH.REJECTION_THRESHOLD)

    # Compute object ids
    # (B, G*G) + (B, 1)
    ids = torch.arange(ARCH.G ** 2, device=x_enc.device).expand(
        B, ARCH.G ** 2) + start_id[:, None]

    # Update temporal states
    state_post_prev = self.get_state_init(B, 'post')
    state_post = self.temporal_encode(state_post_prev, z, bg, prior_or_post='post')
    state_prior_prev = self.get_state_init(B, 'prior')
    state_prior = self.temporal_encode(state_prior_prev, z, bg, prior_or_post='prior')

    # All (B, G*G, D)
    # Conditional kl divergences
    kl_pres = kl_divergence_bern_bern(
        z_pres_post_prob, torch.full_like(z_pres_post_prob, self.z_pres_prior_prob))

    z_depth_prior, z_where_prior, z_what_prior = self.get_discovery_priors(x.device)

    # where prior, (B, G*G, 4)
    kl_where = kl_divergence(z_where_post, z_where_prior)
    kl_where = kl_where * z_pres

    # what prior (B, G*G, D)
    kl_what = kl_divergence(z_what_post, z_what_prior)
    kl_what = kl_what * z_pres

    # depth prior (B, G*G, D)
    kl_depth = kl_divergence(z_depth_post, z_depth_prior)
    kl_depth = kl_depth * z_pres

    # Sum over non-batch dimensions
    kl_pres = kl_pres.flatten(start_dim=1).sum(1)
    kl_where = kl_where.flatten(start_dim=1).sum(1)
    kl_what = kl_what.flatten(start_dim=1).sum(1)
    kl_depth = kl_depth.flatten(start_dim=1).sum(1)

    kl = (kl_pres, kl_depth, kl_where, kl_what)

    return state_post, state_prior, z, ids, kl