Example #1
    def render(self, z):
        """
        Render z into an image
        Args:
            z:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
        Returns:
            fg: (B, 3, H, W)
            alpha_map: (B, 1, H, W)
        """
        z_pres, z_depth, z_where, z_what = z

        B, N, _ = z_pres.size()
        # Reshape to make things easier
        z_pres = z_pres.view(B * N, -1)
        z_where = z_where.view(B * N, -1)
        z_what = z_what.view(B * N, -1)

        # Decode z_what with the glimpse decoder
        # (B*N, 3, H, W), (B*N, 1, H, W)
        o_att, alpha_att = self.glimpse_decoder(z_what)

        # (B*N, 1, H, W)
        alpha_att_hat = alpha_att * z_pres[..., None, None]
        # (B*N, 3, H, W)
        y_att = alpha_att_hat * o_att

        # To full resolution, (B*N, 3, H, W)
        y_att = spatial_transform(y_att,
                                  z_where, (B * N, 3, *ARCH.IMG_SHAPE),
                                  inverse=True)

        # To full resolution, (B*N, 1, H, W)
        alpha_att_hat = spatial_transform(alpha_att_hat,
                                          z_where, (B * N, 1, *ARCH.IMG_SHAPE),
                                          inverse=True)

        y_att = y_att.view(B, N, 3, *ARCH.IMG_SHAPE)
        alpha_att_hat = alpha_att_hat.view(B, N, 1, *ARCH.IMG_SHAPE)

        # (B, N, 1, H, W). H, W here are the full image size.
        importance_map = alpha_att_hat * torch.sigmoid(
            -z_depth[..., None, None])
        importance_map = importance_map / (
            torch.sum(importance_map, dim=1, keepdim=True) + 1e-5)

        # Final fg, (B, 3, H, W)
        fg = (y_att * importance_map).sum(dim=1)
        # Fg mask, (B, 1, H, W)
        alpha_map = (importance_map * alpha_att_hat).sum(dim=1)

        return fg, alpha_map
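Every example on this page calls a `spatial_transform` helper that the snippets themselves do not show. Here is a minimal sketch of how such a spatial-transformer crop/paste can be built on `F.affine_grid`/`F.grid_sample`, assuming `z_where = (scale_x, scale_y, shift_x, shift_y)` in normalized [-1, 1] coordinates; the actual implementations in the SPACE/G-SWM/SCALOR repos may differ in detail.

import torch
import torch.nn.functional as F

def spatial_transform(image, z_where, out_dims, inverse=False):
    """Crop glimpses (inverse=False) or paste them back (inverse=True)."""
    scale, shift = z_where[:, :2], z_where[:, 2:]
    if inverse:
        # Invert the crop: enlarge by 1/scale and move the shift accordingly
        scale = 1.0 / (scale + 1e-9)
        shift = -shift * scale
    # One 2x3 affine matrix per object: scale on the diagonal,
    # shift in the last column
    theta = torch.zeros(image.size(0), 2, 3, device=image.device)
    theta[:, 0, 0], theta[:, 1, 1] = scale[:, 0], scale[:, 1]
    theta[:, :, 2] = shift
    grid = F.affine_grid(theta, list(out_dims), align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)

With `inverse=True`, as in `render` above, each (B*N, 3, h, w) glimpse is pasted onto an empty full-resolution canvas at the location encoded by `z_where`.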
Example #2
def visualize(x,
              z_pres,
              z_where_scale,
              z_where_shift,
              rbox=rbox,
              gbox=gbox,
              num_obj=8 * 8):
    """
        x: (bs, 3, *img_shape)
        z_pres: (bs, 4, 4, 1)
        z_where_scale: (bs, 4, 4, 2)
        z_where_shift: (bs, 4, 4, 2)
    """
    B, _, *img_shape = x.size()
    bs = z_pres.size(0)
    # num_obj = 8 * 8
    z_pres = z_pres.view(-1, 1, 1, 1)
    # z_scale = z_where[:, :, :2].view(-1, 2)
    # z_shift = z_where[:, :, 2:].view(-1, 2)
    z_scale = z_where_scale.view(-1, 2)
    z_shift = z_where_shift.view(-1, 2)
    bbox = spatial_transform(z_pres * gbox + (1 - z_pres) * rbox,
                             torch.cat((z_scale, z_shift), dim=1),
                             torch.Size([bs * num_obj, 3, *img_shape]),
                             inverse=True)
    # Overlay the boxes on num_obj copies of each input image
    bbox = (bbox + torch.stack(num_obj * (x,), dim=1)
            .view(-1, 3, *img_shape)).clamp(0.0, 1.0)
    return bbox
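The `rbox`/`gbox` defaults above are glimpse-sized box templates with a colored border (green where `z_pres` is high, red where it is low), blended via `z_pres * gbox + (1 - z_pres) * rbox`. A sketch of how such templates might be built, with an assumed template size and the helper name `make_box_template` being hypothetical:

import torch

def make_box_template(channel, size=21, width=2):
    # A (1, 3, size, size) image that is zero except for a colored border
    box = torch.zeros(1, 3, size, size)
    box[0, channel, :width, :] = 1.0   # top edge
    box[0, channel, -width:, :] = 1.0  # bottom edge
    box[0, channel, :, :width] = 1.0   # left edge
    box[0, channel, :, -width:] = 1.0  # right edge
    return box

gbox = make_box_template(channel=1)  # green: object present
rbox = make_box_template(channel=0)  # red: object absent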
Example #3
    def bg_attention(self, bg, z):
        """

        Args:
            bg: (B, C, H, W)
            z:

        Returns:
            (B, N, D)

        """
        # (B, N, D)
        z_pres, z_depth, z_where, z_what = z
        B, N, _ = z_pres.size()
        # Enlarge the z_where scale to get background proposal regions
        proposal = z_where.clone()
        proposal[..., :2] += ARCH.BG_PROPOSAL_SIZE

        # Get proposal glimpses
        # (B*N, 3, H, W)
        x_repeat = torch.repeat_interleave(bg, N, dim=0)

        # (B*N, 3, H, W)
        proposal_glimpses = spatial_transform(x_repeat,
                                              proposal.view(B * N, 4),
                                              out_dims=(B * N, 3,
                                                        *ARCH.GLIMPSE_SHAPE))
        # (B, N, 3, H, W)
        proposal_glimpses = proposal_glimpses.view(B, N, 3,
                                                   *ARCH.GLIMPSE_SHAPE)
        # (B, N, D)
        proposal_enc = self.bg_attention_encoder(proposal_glimpses)

        return proposal_enc
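The `ARCH` object used here (and again in Example #10) is the project's global configuration. A stub with assumed values, only to make the shape comments concrete:

class ARCH:
    IMG_SHAPE = (64, 64)       # full-resolution (H, W)
    GLIMPSE_SHAPE = (16, 16)   # per-object glimpse (H, W)
    BG_PROPOSAL_SIZE = 0.1     # enlargement added to the z_where scale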
Example #4
File: utils.py Project: zhixuan-lin/SPACE
def bbox_in_one(x, z_pres, z_where_scale, z_where_shift, gbox=gbox):
    B, _, *img_shape = x.size()
    B, N, _ = z_pres.size()
    z_pres = z_pres.view(-1, 1, 1, 1)
    z_scale = z_where_scale.view(-1, 2)
    z_shift = z_where_shift.view(-1, 2)
    # argmax_cluster = argmax_cluster.view(-1, 1, 1, 1)
    # kbox = boxes[argmax_cluster.view(-1)]
    bbox = spatial_transform(
        z_pres * gbox,  # + (1 - z_pres) * rbox,
        torch.cat((z_scale, z_shift), dim=1),
        torch.Size([B * N, 3, *img_shape]),
        inverse=True)
    bbox = (bbox.view(B, N, 3, *img_shape).sum(dim=1).clamp(0.0, 1.0) +
            x).clamp(0.0, 1.0)
    return bbox
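A hypothetical call, shapes only, assuming the `gbox` template sketched under Example #2 and a 4x4 grid of cells:

import torch

x = torch.rand(4, 3, 64, 64)             # batch of images
z_pres = torch.rand(4, 16, 1)            # presence score per cell
z_scale = torch.full((4, 16, 2), 0.25)   # boxes cover ~1/4 of the image side
z_shift = torch.zeros(4, 16, 2)          # all boxes centered
overlay = bbox_in_one(x, z_pres, z_scale, z_shift)  # (4, 3, 64, 64)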
Example #5
File: utils.py Project: zhixuan-lin/G-SWM
def draw_boxes(images, z_where, z_pres, ids):
    """
    Draw bounding boxes over image.
    Args:
        images: (..., 3, H, W)
        z_where: (..., N, 4)
        z_pres: (..., N, 1). This can be soft
        ids: (..., N). Id of each box

    Returns:
        images: images with boxes drawn
    """
    *ORI, N = ids.size()
    ORIN = np.prod(ORI + [N])

    *_, H, W = images.size()
    # Reshape everything for convenience
    # (ORIN,)
    ids = ids.reshape(-1)
    # (ORIN, 4)
    z_where = z_where.reshape(-1, 4)
    # (ORIN, 1)
    z_pres = z_pres.reshape(-1, 1)
    
    # Get boxes, (ORIN, 3, H, W)
    boxes = get_boxes(ids)
    
    # (ORIN, 3, H, W)
    boxes = spatial_transform(boxes, z_where, (ORIN, 3, H, W), inverse=True)
    
    # Use z_pres as masks
    # (ORIN, 3, H, W) * (ORIN, 1)
    boxes = boxes * z_pres[..., None, None]
    
    # (*ORI, N, 3, H, W)
    boxes = boxes.reshape(*ORI, N, 3, H, W)
    
    # (*ORI, 3, H, W)
    boxes = boxes.sum(dim=-4).clamp_max(1.0)
    
    # (*ORI, 3, H, W)
    img_box = (images + boxes).clamp_max(1.0)
    
    return img_box
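`get_boxes` is not shown; presumably it maps each object id to a box template whose color stays fixed for that track across frames. A sketch under that assumption, with a hypothetical palette:

import torch

def get_boxes(ids, size=64, width=1):
    # Cycle ids through a fixed color palette, (ORIN, 3)
    palette = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.],
                            [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]])
    colors = palette[ids.long() % len(palette)]
    c = colors[..., None, None]  # broadcast over H and W
    boxes = torch.zeros(ids.size(0), 3, size, size)
    boxes[:, :, :width, :] = c   # top edge
    boxes[:, :, -width:, :] = c  # bottom edge
    boxes[:, :, :, :width] = c   # left edge
    boxes[:, :, :, -width:] = c  # right edge
    return boxes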
Example #6
def bbox_in_one(x, z_pres, z_where_scale, z_where_shift, gbox=gbox):
    torchvision.utils.save_image(x, '../output/current_im.png')
    B, _, *img_shape = x.size()
    B, N, _ = z_pres.size()
    # From (B, G*G, D) to (B*G*G, D); N = G*G
    '''z_p = np.squeeze(z_pres.detach().cpu().numpy())
    z_sc = np.squeeze(z_where_scale.detach().cpu().numpy(), axis=0)
    z_sh = np.squeeze(z_where_shift.detach().cpu().numpy(), axis=0)
    indices = z_p > 0.98
    coordinates = z_sh[indices]
    sizes = z_sc[indices]
    fig, ax = plt.subplots()
    ax.scatter([-1, -1, 1, 1], [-1, 1, -1, 1])
    ax.scatter(coordinates[:, 0], coordinates[:, 1], marker='s')
    for i in range(len(coordinates)):
        # Create a Rectangle patch
        x_c = coordinates[i, 0] - sizes[i, 0]
        y_c = coordinates[i, 1] - sizes[i, 1]
        rect = patches.Rectangle((x_c, y_c),
                                 sizes[i, 0], sizes[i, 1], linewidth=1, edgecolor='g', facecolor='none')
        # Add the patch to the Axes
        ax.add_patch(rect)
    plt.savefig('../output/coordinates_im.png')
    plt.close()'''

    z_pres = z_pres.view(-1, 1, 1, 1)
    z_scale = z_where_scale.view(-1, 2)
    z_shift = z_where_shift.view(-1, 2)
    # argmax_cluster = argmax_cluster.view(-1, 1, 1, 1)
    # kbox = boxes[argmax_cluster.view(-1)]
    bbox = spatial_transform(
        z_pres * gbox,  # + (1 - z_pres) * rbox,
        torch.cat((z_scale, z_shift), dim=1),
        torch.Size([B * N, 3, *img_shape]),
        inverse=True)

    bbox_im = bbox
    bbox_im = bbox_im.view(B, N, 3, *img_shape).sum(dim=1).clamp(0.0, 1.0)
    torchvision.utils.save_image(bbox_im, '../output/bbox_im.png')
    # View to (B, N, 3, H, W), sum over objects, and overlay on x
    bbox = (bbox.view(B, N, 3, *img_shape).sum(dim=1).clamp(0.0, 1.0) +
            x).clamp(0.0, 1.0)

    return bbox
Example #7
def add_bbox(x,
             score,
             z_where_scale,
             z_where_shift,
             rbox=rbox,
             gbox=gbox,
             num_obj=8 * 8):
    B, _, *img_shape = x.size()
    bs = score.size(0)
    score = score.view(-1, 1, 1, 1)
    z_scale = z_where_scale.view(-1, 2)
    z_shift = z_where_shift.view(-1, 2)
    bbox = spatial_transform(score * gbox + (1 - score) * rbox,
                             torch.cat((z_scale, z_shift), dim=1),
                             torch.Size([bs * num_obj, 3, *img_shape]),
                             inverse=True)
    bbox = (bbox + x.repeat(1, 3, 1, 1).view(-1, 3, *img_shape)).clamp(
        0.0, 1.0)
    return bbox
Example #8
    def forward(self,
                x,
                img_enc,
                alpha_map_prop,
                ids_prop,
                lengths,
                t,
                eps=1e-15):
        """
            :param z_what_prop: (bs, max_num_obj, dim)
            :param z_where_prop: (bs, max_num_obj, 4)
            :param z_pres_prop: (bs, max_num_obj, 1)
            :param alpha_map_prop: (bs, 1, img_h, img_w)
        """
        bs = x.size(0)
        device = x.device
        alpha_map_prop = alpha_map_prop.detach()

        max_num_disc_obj = (self.max_num_obj - lengths).long()

        self.prior_z_pres_prob = linear_annealing(
            self.args.global_step, self.z_pres_anneal_start_step,
            self.z_pres_anneal_end_step, self.z_pres_anneal_start_value,
            self.z_pres_anneal_end_value, device)

        # z_where: (bs * num_cell_h * num_cell_w, 4)
        # z_pres, z_depth, z_pres_logits: (bs, dim, num_cell_h, num_cell_w)
        (z_where, z_pres, z_depth, z_where_mean, z_where_std,
         z_depth_mean, z_depth_std, z_pres_logits, z_pres_y,
         z_where_origin) = self.ProposalNet(
            img_enc, alpha_map_prop, self.args.tau, t,
            gen_pres_probs=x.new_ones(1) * self.args.gen_disc_pres_probs,
            gen_depth_mean=self.prior_depth_mean,
            gen_depth_std=self.prior_depth_std,
            gen_where_mean=self.prior_where_mean,
            gen_where_std=self.prior_where_std)
        num_cell_h, num_cell_w = z_pres.shape[2], z_pres.shape[3]

        q_z_where = Normal(z_where_mean, z_where_std)
        q_z_depth = Normal(z_depth_mean, z_depth_std)

        # Keep the pre-rejection z_pres for the KL terms below
        z_pres_origin = z_pres

        if self.args.phase_generate and t >= self.args.observe_frames:
            z_what_mean = self.prior_what_mean.view(1, 1).expand(
                bs * self.args.num_cell_h * self.args.num_cell_w, z_what_dim)
            z_what_std = self.prior_what_std.view(1, 1).expand(
                bs * self.args.num_cell_h * self.args.num_cell_w, z_what_dim)
            x_att = x.new_zeros(1)
        else:
            # (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size)
            x_att = spatial_transform(
                torch.stack(num_cell_h * num_cell_w * (x, ),
                            dim=1).view(-1, 3, img_h, img_w),
                z_where,
                (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size),
                inverse=False)

            # (bs * num_cell_h * num_cell_w, dim)
            z_what_mean, z_what_std = self.z_what_net(x_att)
            z_what_std = F.softplus(z_what_std)

        q_z_what = Normal(z_what_mean, z_what_std)

        z_what = q_z_what.rsample()

        # (bs * num_cell_h * num_cell_w, dim, glimpse_size, glimpse_size)
        o_att, alpha_att = self.glimpse_dec(z_what)

        # Rejection
        if phase_rejection and t > 0:
            alpha_map_raw = spatial_transform(
                alpha_att,
                z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
                inverse=True)

            alpha_map_proposed = (alpha_map_raw > 0.3).float()

            alpha_map_prop = (alpha_map_prop > 0.1).float().view(bs, 1, 1, img_h, img_w) \
                .expand(-1, num_cell_h * num_cell_w, -1, -1, -1).contiguous().view(-1, 1, img_h, img_w)

            alpha_map_intersect = alpha_map_proposed * alpha_map_prop

            explained_ratio = alpha_map_intersect.view(bs * num_cell_h * num_cell_w, -1).sum(1) / \
                              (alpha_map_proposed.view(bs * num_cell_h * num_cell_w, -1).sum(1) + eps)

            pres_mask = (explained_ratio <
                         self.args.explained_ratio_threshold).view(
                             bs, 1, num_cell_h, num_cell_w).float()

            z_pres = z_pres * pres_mask

        # The following "if" is useful only if you don't have high-memory GPUs, better to remove it if you do
        if self.training and phase_obj_num_contrain:
            z_pres = z_pres.view(bs, -1)

            z_pres_threshold = z_pres.sort(
                dim=1, descending=True)[0][torch.arange(bs), max_num_disc_obj]

            z_pres_mask = (z_pres > z_pres_threshold.view(bs, -1)).float()

            if self.args.phase_generate and t >= self.args.observe_frames:
                z_pres_mask = x.new_zeros(z_pres_mask.size())

            z_pres = z_pres * z_pres_mask

            z_pres = z_pres.view(bs, 1, num_cell_h, num_cell_w)

        alpha_att_hat = alpha_att * z_pres.view(-1, 1, 1, 1)

        y_att = alpha_att_hat * o_att

        # (bs * num_cell_h * num_cell_w, 3, img_h, img_w)
        y_each_cell = spatial_transform(
            y_att,
            z_where, (bs * num_cell_h * num_cell_w, 3, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, 1, glimpse_size, glimpse_size)
        importance_map = alpha_att_hat * torch.sigmoid(-z_depth).view(
            -1, 1, 1, 1)
        # importance_map = -z_depth.view(-1, 1, 1, 1).expand_as(alpha_att_hat)
        # (bs * num_cell_h * num_cell_w, 1, img_h, img_w)
        importance_map_full_res = spatial_transform(
            importance_map,
            z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, 1, img_h, img_w)
        alpha_map = spatial_transform(
            alpha_att_hat,
            z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, z_what_dim)
        kl_z_what = kl_divergence(q_z_what, self.p_z_what) * \
            z_pres_origin.view(-1, 1)
        # (bs, num_cell_h * num_cell_w, z_what_dim)
        kl_z_what = kl_z_what.view(-1, num_cell_h * num_cell_w, z_what_dim)
        # (bs * num_cell_h * num_cell_w, z_depth_dim)
        kl_z_depth = kl_divergence(q_z_depth, self.p_z_depth) * z_pres_origin
        # (bs, num_cell_h * num_cell_w, z_depth_dim)
        kl_z_depth = kl_z_depth.view(-1, num_cell_h * num_cell_w, z_depth_dim)
        # (bs, dim, num_cell_h, num_cell_w)
        kl_z_where = kl_divergence(q_z_where, self.p_z_where) * z_pres_origin
        if phase_rejection and t > 0:
            kl_z_pres = calc_kl_z_pres_bernoulli(
                z_pres_logits, self.prior_z_pres_prob * pres_mask +
                self.z_pres_masked_prior * (1 - pres_mask))
        else:
            kl_z_pres = calc_kl_z_pres_bernoulli(z_pres_logits,
                                                 self.prior_z_pres_prob)

        kl_z_pres = kl_z_pres.view(-1, num_cell_h * num_cell_w, z_pres_dim)

        ########################################### Compute log importance ############################################
        log_imp = x.new_zeros(bs, 1)
        if not self.training and self.args.phase_nll:
            z_pres_origin_binary = (z_pres_origin > 0.5).float()
            # (bs * num_cell_h * num_cell_w, dim)
            log_imp_what = (
                self.p_z_what.log_prob(z_what) -
                q_z_what.log_prob(z_what)) * z_pres_origin_binary.view(-1, 1)
            log_imp_what = log_imp_what.view(-1, num_cell_h * num_cell_w,
                                             z_what_dim)
            # (bs, dim, num_cell_h, num_cell_w)
            log_imp_depth = (self.p_z_depth.log_prob(z_depth) -
                             q_z_depth.log_prob(z_depth)) * z_pres_origin_binary
            # (bs, dim, num_cell_h, num_cell_w)
            log_imp_where = (
                self.p_z_where.log_prob(z_where_origin) -
                q_z_where.log_prob(z_where_origin)) * z_pres_origin_binary
            if phase_rejection and t > 0:
                p_z_pres = self.prior_z_pres_prob * pres_mask + self.z_pres_masked_prior * (
                    1 - pres_mask)
            else:
                p_z_pres = self.prior_z_pres_prob

            z_pres_binary = (z_pres > 0.5).float()

            log_pres_prior = z_pres_binary * torch.log(p_z_pres + eps) + \
                             (1 - z_pres_binary) * torch.log(1 - p_z_pres + eps)

            log_pres_pos = z_pres_binary * torch.log(torch.sigmoid(z_pres_logits) + eps) + \
                           (1 - z_pres_binary) * torch.log(1 - torch.sigmoid(z_pres_logits) + eps)

            log_imp_pres = log_pres_prior - log_pres_pos

            log_imp = log_imp_what.flatten(start_dim=1).sum(dim=1) + log_imp_depth.flatten(start_dim=1).sum(1) + \
                      log_imp_where.flatten(start_dim=1).sum(1) + log_imp_pres.flatten(start_dim=1).sum(1)

        ######################################## End of Compute log importance #########################################

        # (bs, num_cell_h * num_cell_w)
        ids = torch.arange(num_cell_h * num_cell_w).view(1, -1).expand(bs, -1).to(x.device).float() + \
              ids_prop.max(dim=1, keepdim=True)[0] + 1

        if self.args.log_phase:
            self.log = {
                'z_what': z_what,
                'z_where': z_where,
                'z_pres': z_pres,
                'z_pres_logits': z_pres_logits,
                'z_what_std': q_z_what.stddev,
                'z_what_mean': q_z_what.mean,
                'z_where_std': q_z_where.stddev,
                'z_where_mean': q_z_where.mean,
                'x_att': x_att,
                'y_att': y_att,
                'prior_z_pres_prob': self.prior_z_pres_prob.unsqueeze(0),
                'o_att': o_att,
                'alpha_att_hat': alpha_att_hat,
                'alpha_att': alpha_att,
                'y_each_cell': y_each_cell,
                'z_depth': z_depth,
                'z_depth_std': q_z_depth.stddev,
                'z_depth_mean': q_z_depth.mean,
                # 'importance_map_full_res_norm': importance_map_full_res_norm,
                'z_pres_y': z_pres_y,
                'ids': ids
            }
        else:
            self.log = {}

        return y_each_cell.view(bs, num_cell_h * num_cell_w, 3, img_h, img_w), \
               alpha_map.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \
               importance_map_full_res.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \
               z_what.view(bs, num_cell_h * num_cell_w, -1), z_where.view(bs, num_cell_h * num_cell_w, -1), \
               torch.zeros_like(z_where.view(bs, num_cell_h * num_cell_w, -1)), \
               z_depth.view(bs, num_cell_h * num_cell_w, -1), z_pres.view(bs, num_cell_h * num_cell_w, -1), ids, \
               kl_z_what.flatten(start_dim=1).sum(dim=1), \
               kl_z_where.flatten(start_dim=1).sum(dim=1), \
               kl_z_pres.flatten(start_dim=1).sum(dim=1), \
               kl_z_depth.flatten(start_dim=1).sum(dim=1), \
               log_imp, self.log
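`linear_annealing`, used above to schedule the `z_pres` prior, is presumably a plain linear interpolation between a start and an end value over a step range; a sketch under that assumption:

import torch

def linear_annealing(step, start_step, end_step, start_value, end_value,
                     device):
    # Constant before start_step and after end_step, linear in between
    if step <= start_step:
        value = start_value
    elif step >= end_step:
        value = end_value
    else:
        slope = (end_value - start_value) / (end_step - start_step)
        value = start_value + slope * (step - start_step)
    return torch.tensor(value, device=device)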
Example #9
    def forward(self,
                x,
                act,
                img_enc,
                z_what_pre,
                z_where_pre,
                z_where_bias_pre,
                z_depth_pre,
                z_pres_pre,
                cumsum_one_minus_z_pres,
                ids_pre,
                lengths,
                max_length,
                t,
                eps=1e-15):
        """

        :param x: input image (bs, c, h, w)
        :param act: action (bs, act_dim)
        :param img_enc: input image encode (bs, c, num_cell_h, num_cell_w)
        :param z_what_pre: (bs, max_num_obj, dim)
        :param z_where_pre: (bs, max_num_obj, dim)
        :param z_depth_pre: (bs, max_num_obj, dim)
        :param z_pres_pre: (bs, max_num_obj, dim)
        :param cumsum_one_minus_z_pres: (bs, max_num_obj, dim)
        :param lengths: (bs)
        :return:
        """
        # Optionally zero out the image to keep the model from using it:
        # x = x * 0

        bs = x.size(0)
        device = x.device
        max_num_obj = max_length
        obj_mask = (z_pres_pre.view(bs, max_num_obj) != 0).float()

        bns = bs * max_num_obj

        # obj_rep: (bs, max_num_obj, node_dim)
        obj_rep = torch.cat(
            [z_where_pre, z_pres_pre, z_what_pre, z_where_bias_pre,
             z_depth_pre],
            dim=2)

        if self.node_type is None:
            # node_type_logits: (bs, max_num_obj, 2)
            node_type_logits = self.infer_node_type(node_rep=z_what_pre,
                                                    ignore_edge=True)
            # node_type: (bs, max_num_obj)
            self.node_type = gumbel_softmax(
                node_type_logits.view(bs, max_num_obj),
                hard=hard_gumble_softmax)

        node_type = self.node_type
        # expanded_node_type: (bs, max_num_obj, action_dim)
        expanded_node_type = self.node_type.unsqueeze(2).expand(
            -1, -1, action_dim).contiguous().to(device)
        # expanded_action: (bs, max_num_obj, action_dim)
        expanded_action = act.unsqueeze(1).expand(
            -1, max_num_obj, -1).contiguous().to(device)

        # edge_type_logits: (bs, max_num_obj, max_num_obj, 2)
        edge_type_logits = self.infer_edge_type(node_rep=obj_rep,
                                                ignore_node=True)

        if edge_share:
            edge_type_logits = (edge_type_logits +
                                torch.transpose(edge_type_logits, 1, 2)) / 2.

        # edge_type: (bs * max_num_obj * max_num_obj, 2)
        edge_type = gumbel_softmax(
            edge_type_logits.view(bs * max_num_obj * max_num_obj, 2),
            hard=hard_gumble_softmax)
        # Gate each object's action by its inferred node type
        expanded_action = expanded_action * expanded_node_type

        obj_act_inp = torch.cat(
            [z_where_pre, z_pres_pre, z_what_pre, z_where_bias_pre,
             z_depth_pre, expanded_action],
            dim=2)
        object_transit_out = self.object_transit_net(obj_act_inp,
                                                     None,
                                                     edge_type,
                                                     start_idx=edge_st_idx,
                                                     ignore_edge=True)

        # Note: the MLP output below overwrites the graph-net output above
        object_transit_out = self.object_transit_mlp_net(
            obj_act_inp.view(bns, -1).contiguous()).view(bs, max_num_obj, -1)

        # z_where transition
        z_where_transit_bias_net_inp = torch.cat(
            [object_transit_out, z_what_pre, z_where_pre, z_where_bias_pre],
            dim=2)
        # bns x dim
        z_where_transit_bias_net_inp = z_where_transit_bias_net_inp.view(
            bns, -1).contiguous()

        object_transit_out = object_transit_out.view(bns, -1).contiguous()

        z_where_bias_mean, z_where_bias_std = \
            self.z_where_transit_bias_net(z_where_transit_bias_net_inp).chunk(2, -1)
        z_where_bias_std = F.softplus(z_where_bias_std + self.z_where_std_bias)
        z_where_bias_dist = Normal(z_where_bias_mean, z_where_bias_std)
        z_where_bias = z_where_bias_dist.rsample()

        z_where_pre = z_where_pre.view(bns, -1).contiguous()
        z_where_shift = (z_where_pre[:, 2:] +
                         self.where_update_scale * z_where_bias[:, 2:].tanh())

        scale, ratio = z_where_bias[:, :2].tanh().chunk(2, 1)
        scale = self.args.size_anc + self.args.var_s * scale  # add bias to let masking do its job
        ratio = self.args.ratio_anc + self.args.var_anc * ratio
        ratio_sqrt = ratio.sqrt()

        z_where = torch.cat(
            (scale / ratio_sqrt, scale * ratio_sqrt, z_where_shift), dim=1)
        # Clamp shifts so boxes stay (roughly) within the image
        z_where = torch.cat(
            (z_where[:, :2], z_where[:, 2:].clamp(-1.05, 1.05)), dim=1)

        # z_what transit
        # encode
        x_att = \
            spatial_transform(
                x.unsqueeze(1).expand(-1, max_num_obj, -1, -1, -1).contiguous().view(bns, 3, img_h, img_w), z_where,
                (bns, 3, glimpse_size, glimpse_size), inverse=False
            )

        z_what_from_enc_mean, z_what_from_enc_std = self.z_what_net(x_att)
        z_what_from_enc_std = F.softplus(z_what_from_enc_std)
        z_what_encode_dist = Normal(z_what_from_enc_mean, z_what_from_enc_std)

        # transit
        z_what_from_transit_mean, z_what_from_transit_std = \
            self.z_what_from_transit_net(object_transit_out).chunk(2, -1)

        z_what_from_transit_std = F.softplus(z_what_from_transit_std)
        z_what_transit_dist = Normal(z_what_from_transit_mean,
                                     z_what_from_transit_std)

        # NOTE: the leading `True or` forces the transition branch; the gated
        # encoder/transition fusion in the else-branch is currently disabled
        if True or self.args.phase_generate and t >= self.args.observe_frames:
            z_what_mean = z_what_from_transit_mean
            z_what_std = z_what_from_transit_std
            z_what_dist = z_what_transit_dist
        else:
            z_what_gate_net_inp = object_transit_out
            forget_gate, input_gate = self.z_what_gate_net(
                z_what_gate_net_inp).chunk(2, -1)
            z_what_mean = input_gate * z_what_from_enc_mean + \
                      forget_gate * z_what_from_transit_mean

            z_what_std = F.softplus(input_gate * z_what_from_enc_std + \
                                forget_gate * z_what_from_transit_std)

            z_what_dist = Normal(z_what_mean, z_what_std)

        z_what = z_what_dist.rsample()

        # z depth transit
        z_depth_pre = z_depth_pre.view(bns, -1).contiguous()
        z_depth_transit_net_inp = torch.cat(
            [object_transit_out, z_what, z_depth_pre], dim=1)
        z_depth_mean, z_depth_std = self.z_depth_transit_net(
            z_depth_transit_net_inp).chunk(2, -1)
        z_depth_std = F.softplus(z_depth_std)
        z_depth_dist = Normal(z_depth_mean, z_depth_std)
        z_depth = z_depth_dist.rsample()

        # z_pres bns, dim
        z_pres_transit_inp = torch.cat(
            [object_transit_out, z_where, z_where_bias, z_what], dim=1)
        z_pres_logits = pres_logit_factor * torch.tanh(
            self.z_pres_transit(z_pres_transit_inp) + self.z_pres_logits_bias)

        z_pres_dist = NumericalRelaxedBernoulli(logits=z_pres_logits,
                                                temperature=self.args.tau)
        z_pres_y = z_pres_dist.rsample()
        #z_pres = torch.sigmoid(z_pres_y+1000)
        z_pres = torch.ones(z_pres_y.shape).to(device)  # make it all 1 for now

        o_att, alpha_att = self.glimpse_dec_net(z_what)

        alpha_att_hat = alpha_att * z_pres.view(-1, 1, 1, 1)
        y_att = alpha_att_hat * o_att

        # (bs, 3, img_h, img_w)
        y_each_obj = spatial_transform(y_att,
                                       z_where, (bns, 3, img_h, img_w),
                                       inverse=True)

        # (bns, 1, glimpse_size, glimpse_size)
        importance_map = alpha_att_hat * torch.sigmoid(-z_depth).view(
            -1, 1, 1, 1)
        # (bns, 1, img_h, img_w)
        importance_map_full_res = spatial_transform(importance_map,
                                                    z_where,
                                                    (bns, 1, img_h, img_w),
                                                    inverse=True)

        # (bns, 1, img_h, img_w)
        alpha_map = spatial_transform(alpha_att_hat,
                                      z_where, (bns, 1, img_h, img_w),
                                      inverse=True)
        final_z_pres_mask = z_pres.squeeze() * obj_mask.view(bns)

        kl_z_pres = torch.zeros(bs).to(device)
        kl_z_what = \
            (kl_divergence(z_what_transit_dist, z_what_encode_dist).sum(1) * \
             z_pres.squeeze() * obj_mask.view(bns)).view(bs, max_num_obj).sum(1)
        kl_z_where = torch.zeros(bs).to(device)
        kl_z_depth = torch.zeros(bs).to(device)

        #pres_edge_type_logits = torch.index_select(edge_type_logits.view(-1,2), 0, (final_z_pres_mask == 1).nonzero().squeeze())
        #edge_type_prior = torch.FloatTensor(np.array([1-edge_pos_prior, edge_pos_prior])).cuda()
        #kl_edge_type = -criterionH(pres_edge_type_logits.view(-1,2), edge_type_prior)
        tmp_obj_mask = obj_mask.view(bs, max_num_obj)
        edge_mask = torch.bmm(tmp_obj_mask.unsqueeze(2),
                              tmp_obj_mask.unsqueeze(1)).view(-1)

        kl_edge_type = \
            (calc_kl_z_edge_bernoulli(edge_type_logits.view(-1,2), torch.tensor(edge_pos_prior)) *
                edge_mask.view(-1)).view(bs, max_num_obj * max_num_obj).sum(1)
        ########################################### Compute log importance ############################################
        log_imp = x.new_zeros(bs)
        if not self.training and self.args.phase_nll:
            # Importance weights are not implemented for this module yet
            log_imp = torch.zeros(bs, 1).to(device)

        ######################################## End of Compute log importance #########################################
        z_what_all = z_what.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)
        z_where_dummy = x.new_ones(
            bs, max_num_obj, (z_where_scale_dim + z_where_shift_dim)) * .5
        z_where_dummy[:, :, z_where_scale_dim:] = 2
        z_where_all = z_where.view(bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1) + \
                      z_where_dummy * (1 - obj_mask.view(bs, max_num_obj, 1))
        z_where_bias_all = z_where_bias.view(
            bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1)
        z_pres_all = z_pres.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)

        z_depth_all = z_depth.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)
        y_each_obj_all = \
            y_each_obj.view(bs, max_num_obj, 3, img_h, img_w) * obj_mask.view(bs, max_num_obj, 1, 1, 1)
        alpha_map_all = \
            alpha_map.view(bs, max_num_obj, 1, img_h, img_w) * obj_mask.view(bs, max_num_obj, 1, 1, 1)
        importance_map_all = \
            importance_map_full_res.view(bs, max_num_obj, 1, img_h, img_w) * \
            obj_mask.view(bs, max_num_obj, 1, 1, 1)

        cumsum_one_minus_z_pres = cumsum_one_minus_z_pres.view(
            bs, max_num_obj, -1)

        if self.args.log_phase:
            self.log = {
                'z_what': z_what_all,
                'z_where': z_where_all,
                'z_pres': z_pres_all,
                'z_what_std': z_what_std.view(bs, max_num_obj, -1),
                'z_what_mean': z_what_mean.view(bs, max_num_obj, -1),
                'z_where_bias_std': z_where_bias_std.view(bs, max_num_obj, -1),
                'z_where_bias_mean': z_where_bias_mean.view(bs, max_num_obj, -1),
                'lengths': lengths,
                'z_depth': z_depth_all,
                'z_depth_std': z_depth_std.view(bs, max_num_obj, -1),
                'z_depth_mean': z_depth_mean.view(bs, max_num_obj, -1),
                'y_each_obj': y_each_obj_all.view(bs, max_num_obj, 3, img_h, img_w),
                'alpha_map': alpha_map_all.view(bs, max_num_obj, 1, img_h, img_w),
                'importance_map': importance_map_all.view(bs, max_num_obj, 1, img_h, img_w),
                'z_pres_logits': z_pres_logits.view(bs, max_num_obj, -1),
                'z_pres_y': z_pres_y.view(bs, max_num_obj, -1),
                'o_att': o_att.view(bs, max_num_obj, 3, glimpse_size, glimpse_size),
                'z_where_bias': z_where_bias_all,
                'node_type': node_type,
                'edge_type': edge_type,
                'ids': ids_pre
            }
        else:
            self.log = {}
        return y_each_obj_all, alpha_map_all, importance_map_all, z_what_all, z_where_all, \
               z_where_bias_all, z_depth_all, z_pres_all, ids_pre, kl_z_what, kl_z_where, kl_z_depth, \
               kl_z_pres, kl_edge_type, cumsum_one_minus_z_pres, log_imp, self.log
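The `gumbel_softmax` used above for the discrete node and edge types behaves like `torch.nn.functional.gumbel_softmax`; a thin wrapper sketch, assuming a unit temperature:

import torch.nn.functional as F

def gumbel_softmax(logits, hard=False, tau=1.0):
    # hard=True gives a straight-through one-hot sample: discrete on the
    # forward pass, soft gradients on the backward pass
    return F.gumbel_softmax(logits, tau=tau, hard=hard)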
Example #10
    def propagate(self, x, state_post_prev, state_prior_prev, z_prev, bg):
        """
        Do propagation, conditioned on everything.
        Args:
            x: (B, 3, H, W), img
            state_post_prev, state_prior_prev: (h, c) tuples, each (B, N, D)
            z_prev:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
            bg: (B, 3, H, W), background

        Returns:
            h_post, c_post: (B, N, D)
            h_prior, c_prior: (B, N, D)
            z:
                z_pres: (B, N, 1)
                z_depth: (B, N, 1)
                z_where: (B, N, 4)
                z_what: (B, N, D)
            kl:
                kl_pres: (B,)
                kl_what: (B,)
                kl_where: (B,)
                kl_depth: (B,)
            proposal_region: (B, N, 4)

        """
        z_pres_prev, z_depth_prev, z_where_prev, z_what_prev = z_prev
        B, N, _ = z_pres_prev.size()

        if N == 0:
            # No object is propagated
            return state_post_prev, state_prior_prev, z_prev, \
                (0.0, 0.0, 0.0, 0.0), z_prev[2]

        h_post, c_post = state_post_prev
        h_prior, c_prior = state_prior_prev

        # Predict proposal locations, (B, N, 2)
        proposal_offset = self.pred_proposal(h_post)
        proposal = torch.zeros_like(z_where_prev)
        # Update size only
        proposal[..., 2:] = z_where_prev[..., 2:]
        proposal[..., :2] = z_where_prev[..., :2] + ARCH.PROPOSAL_UPDATE_MIN + (
            ARCH.PROPOSAL_UPDATE_MAX -
            ARCH.PROPOSAL_UPDATE_MIN) * torch.sigmoid(proposal_offset)

        # Get proposal glimpses
        # (B*N, 3, H, W)
        x_repeat = torch.repeat_interleave(x[:, :3], N, dim=0)

        # (B*N, 3, H, W)
        proposal_glimpses = spatial_transform(x_repeat,
                                              proposal.view(B * N, 4),
                                              out_dims=(B * N, 3,
                                                        *ARCH.GLIMPSE_SHAPE))
        # (B, N, 3, H, W)
        proposal_glimpses = proposal_glimpses.view(B, N, 3,
                                                   *ARCH.GLIMPSE_SHAPE)
        # (B, N, D)
        proposal_enc = self.proposal_encoder(proposal_glimpses)
        # (B, N, D)
        # This will be used to condition everything
        enc = torch.cat([proposal_enc, h_post], dim=-1)

        # (B, N, D)
        (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale,
         z_where_offset_loc, z_where_offset_scale, z_what_offset_loc,
         z_what_offset_scale) = self.pres_depth_where_what_post_prop(enc)

        # Sampling
        z_pres_post = RelaxedBernoulli(self.tau, probs=z_pres_prob)
        z_pres = z_pres_post.rsample()
        z_pres = z_pres_prev * z_pres

        z_where_post = Normal(z_where_offset_loc, z_where_offset_scale)
        z_where_offset = z_where_post.rsample()
        z_where = torch.zeros_like(z_where_prev)
        # Scale
        z_where[..., :2] = z_where_prev[..., :2] + \
            ARCH.Z_SCALE_UPDATE_SCALE * torch.tanh(z_where_offset[..., :2])
        # Shift
        z_where[..., 2:] = z_where_prev[..., 2:] + \
            ARCH.Z_SHIFT_UPDATE_SCALE * torch.tanh(z_where_offset[..., 2:])

        z_depth_post = Normal(z_depth_offset_loc, z_depth_offset_scale)
        z_depth_offset = z_depth_post.rsample()
        z_depth = z_depth_prev + ARCH.Z_DEPTH_UPDATE_SCALE * z_depth_offset

        z_what_post = Normal(z_what_offset_loc, z_what_offset_scale)
        z_what_offset = z_what_post.rsample()
        z_what = z_what_prev + ARCH.Z_WHAT_UPDATE_SCALE * torch.tanh(
            z_what_offset)
        z = (z_pres, z_depth, z_where, z_what)

        # Update states
        state_post = self.temporal_encode(state_post_prev,
                                          z,
                                          bg,
                                          prior_or_post='post')
        state_prior = self.temporal_encode(state_prior_prev,
                                           z,
                                           bg,
                                           prior_or_post='prior')

        # Other priors
        (z_pres_prob, z_depth_offset_loc, z_depth_offset_scale,
         z_where_offset_loc, z_where_offset_scale, z_what_offset_loc,
         z_what_offset_scale) = self.pres_depth_where_what_prior_prop(h_prior)

        z_depth_prior = Normal(z_depth_offset_loc, z_depth_offset_scale)
        z_where_prior = Normal(z_where_offset_loc, z_where_offset_scale)
        z_what_prior = Normal(z_what_offset_loc, z_what_offset_scale)

        # This is not a KL divergence; it is an auxiliary loss
        kl_pres = kl_divergence_bern_bern(
            z_pres_prob, torch.full_like(z_pres_prob, self.z_pres_prior_prob))
        kl_depth = kl_divergence(z_depth_post, z_depth_prior)
        kl_depth *= z_pres
        kl_where = kl_divergence(z_where_post, z_where_prior)
        kl_where *= z_pres
        kl_what = kl_divergence(z_what_post, z_what_prior)
        kl_what *= z_pres

        # Reduced to (B,)

        # Again, this is not really kl
        kl_pres = kl_pres.flatten(start_dim=1).sum(-1)
        kl_depth = kl_depth.flatten(start_dim=1).sum(-1)
        kl_where = kl_where.flatten(start_dim=1).sum(-1)
        kl_what = kl_what.flatten(start_dim=1).sum(-1)

        assert kl_pres.size(0) == B
        kl = (kl_pres, kl_depth, kl_where, kl_what)

        return state_post, state_prior, z, kl, proposal
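`kl_divergence_bern_bern` above is presumably the KL divergence between two Bernoulli distributions given their probabilities; a sketch with clamping for numerical safety:

import torch

def kl_divergence_bern_bern(q_prob, p_prob, eps=1e-15):
    # Elementwise KL(q || p) for Bernoulli distributions
    q = q_prob.clamp(eps, 1 - eps)
    p = p_prob.clamp(eps, 1 - eps)
    return q * (q / p).log() + (1 - q) * ((1 - q) / (1 - p)).log()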
Example #11
def log_summary(args,
                writer,
                imgs,
                y_seq,
                global_step,
                log_disc_list,
                log_prop_list,
                scalor_log_list,
                prefix='train',
                eps=1e-15):
    args = copy(args)
    if prefix == 'test':
        args.num_img_summary = args.num_img_summary * 2
    bs = imgs.size(0)
    grid_image = make_grid(imgs[:args.num_img_summary * 2].cpu().view(
        -1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/1-image', grid_image, global_step)

    grid_image = make_grid(y_seq[:args.num_img_summary * 2].cpu().view(
        -1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/2-reconstruction_overall', grid_image,
                     global_step)

    bbox_prop_list = []
    bbox_disc_list = []
    recon_prop_list = []
    recon_disc_list = []
    bg_list = []
    alpha_map_list = []
    x_mask_color_list = []
    # for each time step
    for j in range(imgs.size(1)):

        # Reconstructions from propagated objects and from newly discovered objects
        y_each_obj = scalor_log_list[j]['y_each_obj'][:args.num_img_summary]
        importance_map_norm = scalor_log_list[j][
            'importance_map_norm'][:args.num_img_summary]

        y_prop_disc = y_each_obj * importance_map_norm

        recon_prop_list.append(y_prop_disc[:, :-args.num_cell_h *
                                           args.num_cell_w].sum(dim=1))
        recon_disc_list.append(y_prop_disc[:, -args.num_cell_h *
                                           args.num_cell_w:].sum(dim=1))
        bg_list.append(scalor_log_list[j]['bg'][:args.num_img_summary])
        alpha_map_list.append(
            scalor_log_list[j]['alpha_map'][:args.num_img_summary])
        x_mask_color_list.append(
            scalor_log_list[j]['x_mask_color'][:args.num_img_summary])

        if prefix == 'train' and not args.phase_simplify_summary:
            writer.add_histogram(
                f'{prefix}_inside_value_scalor_{j}/importance_map_norm',
                scalor_log_list[j]['importance_map_norm'][
                    scalor_log_list[j]['importance_map_norm'] > 0].cpu(
                    ).detach().numpy(), global_step)
            for k, v in scalor_log_list[j].items():
                if '_bg_' in k:
                    writer.add_histogram(
                        f'{prefix}_inside_value_scalor_{j}/{k}',
                        v.cpu().detach().numpy(), global_step)
            if args.phase_conv_lstm:
                for k, v in scalor_log_list[j].items():
                    if 'lstm' in k:
                        writer.add_histogram(
                            f'{prefix}_inside_value_scalor_{j}/{k}',
                            v.cpu().detach().numpy(), global_step)

        log_disc = {
            'z_what': log_disc_list[j]['z_what'].view(-1, 8 * 8, z_what_dim),
            'z_where_scale':
                log_disc_list[j]['z_where'].view(-1, 8 * 8, z_where_scale_dim + z_where_shift_dim)[:, :,
                :z_where_scale_dim],
            'z_where_shift':
                log_disc_list[j]['z_where'].view(-1, 8 * 8, z_where_scale_dim + z_where_shift_dim)[:, :,
                z_where_scale_dim:],
            'z_pres': log_disc_list[j]['z_pres'].permute(0, 2, 3, 1),
            'z_pres_probs': torch.sigmoid(log_disc_list[j]['z_pres_logits']).permute(0, 2, 3, 1),
            'z_what_std': log_disc_list[j]['z_what_std'].view(-1, 8 * 8, z_what_dim),
            'z_what_mean': log_disc_list[j]['z_what_mean'].view(-1, 8 * 8, z_what_dim),
            'z_where_scale_std':
                log_disc_list[j]['z_where_std'].permute(0, 2, 3, 1)[:, :, :z_where_scale_dim],
            'z_where_scale_mean':
                log_disc_list[j]['z_where_mean'].permute(0, 2, 3, 1)[:, :, :z_where_scale_dim],
            'z_where_shift_std':
                log_disc_list[j]['z_where_std'].permute(0, 2, 3, 1)[:, :, z_where_scale_dim:],
            'z_where_shift_mean':
                log_disc_list[j]['z_where_mean'].permute(0, 2, 3, 1)[:, :, z_where_scale_dim:],
            'glimpse': log_disc_list[j]['x_att'].view(-1, 8 * 8, 3, glimpse_size, glimpse_size) \
                if prefix != 'generate' else None,
            'glimpse_recon': log_disc_list[j]['y_att'].view(-1, 8 * 8, 3, glimpse_size, glimpse_size),
            'prior_z_pres_prob': log_disc_list[j]['prior_z_pres_prob'].unsqueeze(0),
            'o_each_cell': spatial_transform(log_disc_list[j]['o_att'], log_disc_list[j]['z_where'],
                                             (8 * 8 * bs, 3, img_h, img_w),
                                             inverse=True).view(-1, 8 * 8, 3, img_h, img_w),
            'alpha_hat_each_cell': spatial_transform(log_disc_list[j]['alpha_att_hat'],
                                                     log_disc_list[j]['z_where'],
                                                     (8 * 8 * bs, 1, img_h, img_w),
                                                     inverse=True).view(-1, 8 * 8, 1, img_h, img_w),
            'alpha_each_cell': spatial_transform(log_disc_list[j]['alpha_att'], log_disc_list[j]['z_where'],
                                                 (8 * 8 * bs, 1, img_h, img_w),
                                                 inverse=True).view(-1, 8 * 8, 1, img_h, img_w),
            'y_each_cell': (log_disc_list[j]['y_each_cell'] * log_disc_list[j]['z_pres'].
                            view(-1, 1, 1, 1)).view(-1, 8 * 8, 3, img_h, img_w),
            'z_depth': log_disc_list[j]['z_depth'].view(-1, 8 * 8, z_depth_dim),
            'z_depth_std': log_disc_list[j]['z_depth_std'].view(-1, 8 * 8, z_depth_dim),
            'z_depth_mean': log_disc_list[j]['z_depth_mean'].view(-1, 8 * 8, z_depth_dim),
            'z_pres_logits': log_disc_list[j]['z_pres_logits'].permute(0, 2, 3, 1),
            'z_pres_y': log_disc_list[j]['z_pres_y'].permute(0, 2, 3, 1)
        }

        bbox = visualize(
            imgs[:args.num_img_summary, j].cpu(),
            log_disc['z_pres'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_scale'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_shift'][:args.num_img_summary].cpu().detach())

        y_each_cell = log_disc['y_each_cell'].view(
            -1, 3, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()
        o_each_cell = log_disc['o_each_cell'].view(
            -1, 3, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()
        alpha_each_cell = log_disc['alpha_hat_each_cell'].view(
            -1, 1, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()

        if log_prop_list[j]:
            log_prop = {
                'z_what': log_prop_list[j]['z_what'].view(bs, -1, z_what_dim),
                'z_where_scale': log_prop_list[j]['z_where'].view(
                    bs, -1, z_where_scale_dim + z_where_shift_dim)[:, :, :z_where_scale_dim],
                'z_where_shift': log_prop_list[j]['z_where'].view(
                    bs, -1, z_where_scale_dim + z_where_shift_dim)[:, :, z_where_scale_dim:],
                'z_pres': log_prop_list[j]['z_pres'],
                'z_what_std': log_prop_list[j]['z_what_std'].view(bs, -1, z_what_dim),
                'z_what_mean': log_prop_list[j]['z_what_mean'].view(bs, -1, z_what_dim),
                'z_where_bias_scale_std': log_prop_list[j]['z_where_bias_std'][:, :, :z_where_scale_dim],
                'z_where_bias_scale_mean': log_prop_list[j]['z_where_bias_mean'][:, :, :z_where_scale_dim],
                'z_where_bias_shift_std': log_prop_list[j]['z_where_bias_std'][:, :, z_where_scale_dim:],
                'z_where_bias_shift_mean': log_prop_list[j]['z_where_bias_mean'][:, :, z_where_scale_dim:],
                'z_pres_probs': torch.sigmoid(log_prop_list[j]['z_pres_logits']),
                'glimpse': log_prop_list[j]['glimpse'],
                'glimpse_recon': log_prop_list[j]['glimpse_recon'],
                'prior_z_pres_prob': log_prop_list[j]['prior_z_pres_prob'],
                'prior_where_bias_scale_std': log_prop_list[j]['prior_where_bias_std'][:, :, :z_where_scale_dim],
                'prior_where_bias_scale_mean': log_prop_list[j]['prior_where_bias_mean'][:, :, :z_where_scale_dim],
                'prior_where_bias_shift_std': log_prop_list[j]['prior_where_bias_std'][:, :, z_where_scale_dim:],
                'prior_where_bias_shift_mean': log_prop_list[j]['prior_where_bias_mean'][:, :, z_where_scale_dim:],
                'lengths': log_prop_list[j]['lengths'],
                'z_depth': log_prop_list[j]['z_depth'],
                'z_depth_std': log_prop_list[j]['z_depth_std'],
                'z_depth_mean': log_prop_list[j]['z_depth_mean'],
                'y_each_obj': log_prop_list[j]['y_each_obj'],
                'alpha_hat_each_obj': log_prop_list[j]['alpha_map'],
                'z_pres_logits': log_prop_list[j]['z_pres_logits'],
                'z_pres_y': log_prop_list[j]['z_pres_y'],
                'o_each_obj': spatial_transform(
                    log_prop_list[j]['o_att'].view(-1, 3, glimpse_size, glimpse_size),
                    log_prop_list[j]['z_where'].view(-1, z_where_scale_dim + z_where_shift_dim),
                    (log_prop_list[j]['o_att'].size(1) * bs, 3, img_h, img_w),
                    inverse=True).view(bs, -1, 3, img_h, img_w),
                'z_where_bias_scale': log_prop_list[j]['z_where_bias'].view(
                    bs, -1, z_where_scale_dim + z_where_shift_dim)[:, :, :z_where_scale_dim],
                'z_where_bias_shift': log_prop_list[j]['z_where_bias'].view(
                    bs, -1, z_where_scale_dim + z_where_shift_dim)[:, :, z_where_scale_dim:],
            }

            num_obj = log_prop['z_pres'].size(1)
            idx = [[], []]
            for k in range(bs):
                for l in range(int(log_prop['lengths'][k])):
                    idx[0].append(k)
                    idx[1].append(l)
            idx_false = [[], []]
            for k in range(bs):
                for l in range(num_obj - int(log_prop['lengths'][k])):
                    idx_false[0].append(k)
                    idx_false[1].append(int(log_prop['lengths'][k] + l))
            if prefix == 'train' and not args.phase_simplify_summary:
                for key, value in log_prop.items():
                    if key == 'lengths':
                        writer.add_histogram(
                            f'{prefix}_inside_value_prop_{j}/{key}',
                            value.cpu().detach().numpy(), global_step)
                    else:
                        writer.add_histogram(
                            f'{prefix}_inside_value_prop_{j}/{key}',
                            value.cpu().detach()[idx].numpy(), global_step)

            bbox_prop = visualize(
                imgs[:args.num_img_summary, j].cpu(),
                log_prop['z_pres'][:args.num_img_summary].cpu().detach(),
                log_prop['z_where_scale']
                [:args.num_img_summary].cpu().detach(),
                log_prop['z_where_shift']
                [:args.num_img_summary].cpu().detach(),
                only_bbox=True)

            bbox_prop = bbox_prop.view(args.num_img_summary, -1, 3, img_h,
                                       img_w)
            bbox_prop_one_time_step = (
                bbox_prop.sum(dim=1) +
                imgs[:args.num_img_summary, j].cpu()).clamp(0, 1)
            bbox_prop_list.append(bbox_prop_one_time_step)
        else:
            bbox_prop_one_time_step = imgs[:args.num_img_summary, j].cpu()
            bbox_prop_list.append(bbox_prop_one_time_step)
        if prefix == 'train' and not args.phase_simplify_summary:
            for key, value in log_disc.items():
                writer.add_histogram(f'{prefix}_inside_value_disc_{j}/{key}',
                                     value.cpu().detach().numpy(), global_step)

        if not args.phase_simplify_summary:
            for m in range(int(min(args.num_img_summary, bs))):

                grid_image = make_grid(
                    bbox[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                         args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/1-bbox_{m}_{j}', grid_image,
                                 global_step)

                grid_image = make_grid(
                    y_each_cell[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                                args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/2-y_each_cell_{m}_{j}',
                                 grid_image, global_step)

                grid_image = make_grid(
                    o_each_cell[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                                args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/3-o_each_cell_{m}_{j}',
                                 grid_image, global_step)

                grid_image = make_grid(
                    alpha_each_cell[m * args.num_cell_h *
                                    args.num_cell_w:(m + 1) * args.num_cell_h *
                                    args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(
                    f'{prefix}_disc/4-alpha_hat_each_cell_{m}_{j}', grid_image,
                    global_step)

                if log_prop_list[j]:
                    bbox_prop = visualize(
                        imgs[m, j].cpu(), log_prop['z_pres'][m].cpu().detach(),
                        log_prop['z_where_scale'][m].cpu().detach(),
                        log_prop['z_where_shift'][m].cpu().detach())

                    grid_image = make_grid(bbox_prop,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/1-bbox_{m}_{j}',
                                     grid_image, global_step)

                    y_each_obj = log_prop['y_each_obj'][m].view(
                        -1, 3, img_h, img_w).cpu().detach()
                    grid_image = make_grid(y_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/2-y_each_obj_{m}_{j}',
                                     grid_image, global_step)

                    o_each_obj = log_prop['o_each_obj'][m].view(
                        -1, 3, img_h, img_w).cpu().detach()
                    grid_image = make_grid(o_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/3-o_each_obj_{m}_{j}',
                                     grid_image, global_step)

                    alpha_each_obj = log_prop['alpha_hat_each_obj'][m].view(
                        -1, 1, img_h, img_w).cpu().detach()
                    grid_image = make_grid(alpha_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/4-alpha_each_obj_{m}_{j}',
                                     grid_image, global_step)

        bbox_disc = visualize(
            imgs[:args.num_img_summary, j].cpu(),
            log_disc['z_pres'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_scale'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_shift'][:args.num_img_summary].cpu().detach(),
            only_bbox=True)
        bbox_disc = bbox_disc.view(args.num_img_summary, -1, 3, img_h, img_w)
        bbox_disc = (bbox_disc.sum(dim=1) +
                     imgs[:args.num_img_summary, j].cpu()).clamp(0, 1)
        bbox_disc_list.append(bbox_disc)

    recon_disc = torch.stack(recon_disc_list, dim=1)
    grid_image = make_grid(recon_disc.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/3-reconstruction_disc', grid_image,
                     global_step)

    recon_prop = torch.stack(recon_prop_list, dim=1)
    grid_image = make_grid(recon_prop.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/4-reconstruction_prop', grid_image,
                     global_step)

    bbox_disc_all = torch.stack(bbox_disc_list, dim=1)
    grid_image = make_grid(bbox_disc_all.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/5-bbox_disc', grid_image, global_step)

    bbox_prop_all = torch.stack(bbox_prop_list, dim=1)
    grid_image = make_grid(bbox_prop_all.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/6-bbox_prop', grid_image, global_step)

    bg = torch.stack(bg_list, dim=1)
    grid_image = make_grid(bg.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/7-background', grid_image, global_step)

    alpha_map = torch.stack(alpha_map_list, dim=1)
    grid_image = make_grid(alpha_map.view(-1, 1, img_h, img_w),
                           seq_len,
                           normalize=False,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/8-alpha-map', grid_image, global_step)

    x_mask_color = torch.stack(x_mask_color_list, dim=1)
    grid_image = make_grid(x_mask_color.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=False,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/9-x-mask-color', grid_image,
                     global_step)

    return
    def forward(self,
                x,
                img_enc,
                temporal_rnn_out_pre,
                temporal_rnn_hid_pre,
                prior_rnn_out_pre,
                prior_rnn_hid_pre,
                z_what_pre,
                z_where_pre,
                z_where_bias_pre,
                z_depth_pre,
                z_pres_pre,
                cumsum_one_minus_z_pres,
                ids_pre,
                lengths,
                max_length,
                t,
                eps=1e-15):
        """

        :param x: input image (bs, c, h, w)
        :param img_enc: input image encode (bs, c, num_cell_h, num_cell_w)
        :param temporal_rnn_out_pre: (bs, max_num_obj, dim)
        :param temporal_rnn_hid_pre: (bs, max_num_obj, dim)
        :param z_what_pre: (bs, max_num_obj, dim)
        :param z_where_pre: (bs, max_num_obj, dim)
        :param z_depth_pre: (bs, max_num_obj, dim)
        :param z_pres_pre: (bs, max_num_obj, dim)
        :param cumsum_one_minus_z_pres: (bs, max_num_obj, dim)
        :param lengths: (bs)
        :return:
        """

        bs = x.size(0)
        device = x.device
        max_num_obj = max_length
        bns = bs * max_num_obj
        obj_mask = (z_pres_pre.view(bs, max_num_obj) != 0).float()

        temporal_rnn_out_pre, temporal_rnn_hid_pre, prior_rnn_out_pre, \
        prior_rnn_hid_pre, z_what_pre, z_where_pre, z_where_bias_pre, z_depth_pre, \
        z_pres_pre, cumsum_one_minus_z_pres = \
            temporal_rnn_out_pre.view(bns, -1), temporal_rnn_hid_pre.view(bns, -1), \
            prior_rnn_out_pre.view(bns, -1), prior_rnn_hid_pre.view(bns, -1), \
            z_what_pre.view(bns, -1), z_where_pre.view(bns, -1), z_where_bias_pre.view(bns, -1), \
            z_depth_pre.view(bns, -1), z_pres_pre.view(bns, -1), \
            cumsum_one_minus_z_pres.view(bns, -1)

        prior_rnn_out, prior_rnn_hid, prior_what_mean, prior_what_std, prior_where_bias_mean, \
        prior_where_bias_std, prior_depth_mean, prior_depth_std, prior_pres_prob = \
            self.prior_cell(prior_rnn_out_pre, prior_rnn_hid_pre, z_what_pre,
                            z_where_pre, z_where_bias_pre, z_depth_pre, z_pres_pre)

        z_where_att = x.new_ones(z_where_pre.size()) * .5
        z_where_att[:, 2:] = z_where_pre[:, 2:].detach()
        img_enc_att = spatial_transform(
            img_enc.unsqueeze(1).expand(-1, max_num_obj, -1, -1,
                                        -1).contiguous().view(
                                            bns, img_encode_dim,
                                            self.args.num_cell_h,
                                            self.args.num_cell_w),
            z_where_att, (bns, img_encode_dim, self.args.num_cell_h // 2,
                          self.args.num_cell_w // 2),
            inverse=False)
        # bns, dim
        temporal_img_enc = self.attention_encoding(img_enc_att).view(
            -1, temporal_img_enc_dim)

        temporal_rnn_inp_net_inp = torch.cat(
            [z_where_pre, z_pres_pre, z_what_pre, z_where_bias_pre,
             temporal_img_enc], dim=1)
        temporal_rnn_inp = self.temporal_rnn_inp_net(temporal_rnn_inp_net_inp)
        # bns, dim
        temporal_rnn_out, temporal_rnn_hid = self.temporal_rnn(
            temporal_rnn_out_pre, temporal_rnn_hid_pre, temporal_rnn_inp)

        # z_where transition
        z_where_transit_bias_net_inp = torch.cat(
            [temporal_rnn_out, z_what_pre, z_where_pre, z_where_bias_pre,
             temporal_img_enc], dim=1)
        z_where_bias_mean, z_where_bias_std = \
            self.z_where_transit_bias_net(z_where_transit_bias_net_inp).chunk(2, -1)
        z_where_bias_std = F.softplus(z_where_bias_std + self.z_where_std_bias)
        if self.args.phase_generate and t >= self.args.observe_frames:
            z_where_bias_dist = Normal(prior_where_bias_mean,
                                       prior_where_bias_std)
        else:
            z_where_bias_dist = Normal(z_where_bias_mean, z_where_bias_std)

        z_where_bias = z_where_bias_dist.rsample()

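        # Residual position update: a tanh-squashed bias, scaled by
        # where_update_scale, nudges the previous shift.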
        z_where_shift = z_where_pre[:, 2:] + \
            self.where_update_scale * z_where_bias[:, 2:].tanh()

        scale, ratio = z_where_bias[:, :2].tanh().chunk(2, 1)
        scale = self.args.size_anc + self.args.var_s * scale  # add bias to let masking do its job
        ratio = self.args.ratio_anc + self.args.var_anc * ratio
        ratio_sqrt = ratio.sqrt()

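        # Box size comes from an overall scale and an aspect ratio:
        # the two components are scale / sqrt(ratio) and scale * sqrt(ratio).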
        z_where = torch.cat(
            (scale / ratio_sqrt, scale * ratio_sqrt, z_where_shift), dim=1)

        # # always within the image
        z_where = torch.cat(
            (z_where[:, :2], z_where[:, 2:].clamp(-1.05, 1.05)), dim=1)

        # get glimpse encode
        x_att = \
            spatial_transform(
                x.unsqueeze(1).expand(-1, max_num_obj, -1, -1, -1).contiguous().view(bns, 3, img_h, img_w), z_where,
                (bns, 3, glimpse_size, glimpse_size), inverse=False
            )

        z_what_from_enc_mean, z_what_from_enc_std = self.z_what_net(x_att)
        z_what_from_enc_std = F.softplus(z_what_from_enc_std)

        # z_what transit
        z_what_from_temporal_mean, z_what_from_temporal_std = \
            self.z_what_from_temporal_net(temporal_rnn_out).chunk(2, -1)

        z_what_from_temporal_std = F.softplus(z_what_from_temporal_std)

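        # Learned gates decide how much of z_what comes from the current
        # glimpse encoding versus the propagated temporal prediction.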
        z_what_gate_net_inp = torch.cat((temporal_rnn_out, temporal_img_enc),
                                        dim=1)
        forget_gate, input_gate = self.z_what_gate_net(
            z_what_gate_net_inp).chunk(2, -1)

        z_what_mean = input_gate * z_what_from_enc_mean + \
                      forget_gate * z_what_from_temporal_mean

        z_what_std = F.softplus(input_gate * z_what_from_enc_std + \
                                forget_gate * z_what_from_temporal_std)

        if self.args.phase_generate and t >= self.args.observe_frames:
            z_what_dist = Normal(prior_what_mean, prior_what_std)
        else:
            z_what_dist = Normal(z_what_mean, z_what_std)

        z_what = z_what_dist.rsample()

        z_depth_transit_net_inp = torch.cat(
            [temporal_rnn_out, z_what, temporal_img_enc], dim=1)
        z_depth_mean, z_depth_std = self.z_depth_transit_net(
            z_depth_transit_net_inp).chunk(2, -1)
        z_depth_std = F.softplus(z_depth_std)

        if self.args.phase_generate and t >= self.args.observe_frames:
            z_depth_dist = Normal(prior_depth_mean, prior_depth_std)
        else:
            z_depth_dist = Normal(z_depth_mean, z_depth_std)

        z_depth = z_depth_dist.rsample()

        # z_pres bns, dim
        z_pres_transit_inp = torch.cat(
            [temporal_rnn_out, z_where, z_where_bias, z_what], dim=1)
        z_pres_logits = pres_logit_factor * torch.tanh(
            self.z_pres_transit(z_pres_transit_inp) + self.z_pres_logits_bias)
        if self.args.phase_generate and t >= self.args.observe_frames:
            q_z_pres = NumericalRelaxedBernoulli(probs=prior_pres_prob,
                                                 temperature=self.args.tau)
        else:
            q_z_pres = NumericalRelaxedBernoulli(logits=z_pres_logits,
                                                 temperature=self.args.tau)

        # for z_pres, we end up setting this to one during generation
        z_pres_y = q_z_pres.rsample()
        z_pres = torch.sigmoid(z_pres_y)

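        # Track how much "absence" has accumulated; once it exceeds the stop
        # threshold, the object is considered gone and z_pres is zeroed.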
        cumsum_one_minus_z_pres += (1 - z_pres) * obj_mask.view(bns, 1)
        z_pres = z_pres * (cumsum_one_minus_z_pres <
                           self.z_pres_stop_threshold).float()

        # (bs, dim, glimpse_size, glimpse_size)
        o_att, alpha_att = self.glimpse_dec_net(z_what)

        alpha_att_hat = alpha_att * z_pres.view(-1, 1, 1, 1)
        y_att = alpha_att_hat * o_att

        # (bs, 3, img_h, img_w)
        y_each_obj = spatial_transform(y_att,
                                       z_where, (bns, 3, img_h, img_w),
                                       inverse=True)

        # (batch_size_t, 1, glimpse_size, glimpse_size)
        importance_map = alpha_att_hat * torch.sigmoid(-z_depth).view(
            -1, 1, 1, 1)

        # (batch_size_t, 1, img_h, img_w)
        importance_map_full_res = spatial_transform(importance_map,
                                                    z_where,
                                                    (bns, 1, img_h, img_w),
                                                    inverse=True)

        # (batch_size_t, 1, img_h, img_w)
        alpha_map = spatial_transform(alpha_att_hat,
                                      z_where, (bns, 1, img_h, img_w),
                                      inverse=True)

        kl_z_pres = \
            (calc_kl_z_pres_bernoulli(z_pres_logits, prior_pres_prob) *
             obj_mask.view(bns)).view(bs, max_num_obj).sum(1)

        prior_what_dist = Normal(prior_what_mean, prior_what_std)
        prior_where_bias_dist = Normal(prior_where_bias_mean,
                                       prior_where_bias_std)
        prior_depth_dist = Normal(prior_depth_mean, prior_depth_std)

        kl_z_what = \
            (kl_divergence(z_what_dist, prior_what_dist).sum(1) * \
             z_pres.squeeze() * obj_mask.view(bns)).view(bs, max_num_obj).sum(1)
        kl_z_where = \
            (kl_divergence(z_where_bias_dist, prior_where_bias_dist).sum(1) * \
             z_pres.squeeze() * obj_mask.view(bns)).view(bs, max_num_obj).sum(1)
        kl_z_depth = \
            (kl_divergence(z_depth_dist, prior_depth_dist).sum(1) * \
             z_pres.squeeze() * obj_mask.view(bns)).view(bs, max_num_obj).sum(1)

        ########################################### Compute log importance ############################################
        log_imp = x.new_zeros(bs, 1)
        if not self.training and self.args.phase_nll:
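            # Importance weights for NLL estimation: log p(z) - log q(z | x),
            # summed over the latents of objects that are actually present.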
            z_pres_binary = (z_pres > 0.5).float()
            # (bns, dim)
            log_imp_what = (prior_what_dist.log_prob(z_what) - z_what_dist.log_prob(z_what)) * \
                           z_pres_binary * obj_mask.view(bns, 1)
            log_imp_depth = (prior_depth_dist.log_prob(z_depth) - z_depth_dist.log_prob(z_depth)) * \
                            z_pres_binary * obj_mask.view(bns, 1)
            log_imp_where = (prior_where_bias_dist.log_prob(z_where_bias) - z_where_bias_dist.log_prob(z_where_bias)) * \
                            z_pres_binary * obj_mask.view(bns, 1)

            log_pres_prior = z_pres_binary * torch.log(prior_pres_prob + eps) + \
                             (1 - z_pres_binary) * torch.log(1 - prior_pres_prob + eps)

            log_pres_pos = z_pres_binary * torch.log(torch.sigmoid(z_pres_logits) + eps) + \
                           (1 - z_pres_binary) * torch.log(1 - torch.sigmoid(z_pres_logits) + eps)

            log_imp_pres = (log_pres_prior - log_pres_pos) * obj_mask.view(
                bns, 1)

            log_imp = log_imp_what.view(bs, -1).sum(1) + log_imp_depth.view(bs, -1).sum(1) + \
                      log_imp_where.view(bs, -1).sum(1) + log_imp_pres.view(bs, -1).sum(1)

        ######################################## End of Compute log importance #########################################
        z_what_all = z_what.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)
        z_where_dummy = x.new_ones(
            bs, max_num_obj, (z_where_scale_dim + z_where_shift_dim)) * .5
        z_where_dummy[:, :, z_where_scale_dim:] = 2
        z_where_all = z_where.view(bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1) + \
                      z_where_dummy * (1 - obj_mask.view(bs, max_num_obj, 1))
        z_where_bias_all = z_where_bias.view(
            bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1)
        z_pres_all = z_pres.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)
        temporal_rnn_hid_all = \
            temporal_rnn_hid.view(bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1)
        temporal_rnn_out_all = \
            temporal_rnn_out.view(bs, max_num_obj, -1) * obj_mask.view(bs, max_num_obj, 1)
        z_depth_all = z_depth.view(bs, max_num_obj, -1) * obj_mask.view(
            bs, max_num_obj, 1)
        y_each_obj_all = \
            y_each_obj.view(bs, max_num_obj, 3, img_h, img_w) * obj_mask.view(bs, max_num_obj, 1, 1, 1)
        alpha_map_all = \
            alpha_map.view(bs, max_num_obj, 1, img_h, img_w) * obj_mask.view(bs, max_num_obj, 1, 1, 1)
        importance_map_all = \
            importance_map_full_res.view(bs, max_num_obj, 1, img_h, img_w) * \
            obj_mask.view(bs, max_num_obj, 1, 1, 1)

        cumsum_one_minus_z_pres = cumsum_one_minus_z_pres.view(
            bs, max_num_obj, -1)
        prior_rnn_out = prior_rnn_out.view(bs, max_num_obj, -1)
        prior_rnn_hid = prior_rnn_hid.view(bs, max_num_obj, -1)

        if self.args.log_phase:
            self.log = {
                'z_what':
                z_what_all,
                'z_where':
                z_where_all,
                'z_pres':
                z_pres_all,
                'z_what_std':
                z_what_std.view(bs, max_num_obj, -1),
                'z_what_mean':
                z_what_mean.view(bs, max_num_obj, -1),
                'z_where_bias_std':
                z_where_bias_std.view(bs, max_num_obj, -1),
                'z_where_bias_mean':
                z_where_bias_mean.view(bs, max_num_obj, -1),
                'glimpse':
                x_att.view(bs, max_num_obj, 3, glimpse_size, glimpse_size),
                'glimpse_recon':
                y_att.view(bs, max_num_obj, 3, glimpse_size, glimpse_size),
                'prior_z_pres_prob':
                prior_pres_prob.view(bs, max_num_obj, -1),
                'prior_where_bias_std':
                prior_where_bias_std.view(bs, max_num_obj, -1),
                'prior_where_bias_mean':
                prior_where_bias_mean.view(bs, max_num_obj, -1),
                'prior_what_mean':
                prior_what_mean.view(bs, max_num_obj, -1),
                'prior_what_std':
                prior_what_std.view(bs, max_num_obj, -1),
                'lengths':
                lengths,
                'z_depth':
                z_depth_all,
                'z_depth_std':
                z_depth_std.view(bs, max_num_obj, -1),
                'z_depth_mean':
                z_depth_mean.view(bs, max_num_obj, -1),
                'y_each_obj':
                y_each_obj_all.view(bs, max_num_obj, 3, img_h, img_w),
                'alpha_map':
                alpha_map_all.view(bs, max_num_obj, 1, img_h, img_w),
                'importance_map':
                importance_map_all.view(bs, max_num_obj, 1, img_h, img_w),
                'z_pres_logits':
                z_pres_logits.view(bs, max_num_obj, -1),
                'z_pres_y':
                z_pres_y.view(bs, max_num_obj, -1),
                'o_att':
                o_att.view(bs, max_num_obj, 3, glimpse_size, glimpse_size),
                'z_where_bias':
                z_where_bias_all,
                'ids':
                ids_pre
            }
        else:
            self.log = {}

        return y_each_obj_all, alpha_map_all, importance_map_all, z_what_all, z_where_all, \
               z_where_bias_all, z_depth_all, z_pres_all, ids_pre, kl_z_what, kl_z_where, kl_z_depth, \
               kl_z_pres, temporal_rnn_out_all, temporal_rnn_hid_all, prior_rnn_out, \
               prior_rnn_hid, cumsum_one_minus_z_pres, log_imp, self.log
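
The z_what update above blends two proposals with learned gates: one from the
current glimpse encoder and one predicted from the temporal RNN. Below is a
minimal, self-contained sketch of that gating pattern; module and argument
names such as GatedFusion and ctx_dim are illustrative, not from the original
code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFusion(nn.Module):
    """Blend two (mean, std) proposals with gates predicted from a context."""

    def __init__(self, ctx_dim: int, z_dim: int):
        super().__init__()
        # A single head predicts both gates in [0, 1] from the context.
        self.gate_net = nn.Sequential(nn.Linear(ctx_dim, 2 * z_dim),
                                      nn.Sigmoid())

    def forward(self, ctx, enc_mean, enc_std, tmp_mean, tmp_std):
        forget_gate, input_gate = self.gate_net(ctx).chunk(2, dim=-1)
        mean = input_gate * enc_mean + forget_gate * tmp_mean
        std = F.softplus(input_gate * enc_std + forget_gate * tmp_std)
        return mean, std

fusion = GatedFusion(ctx_dim=32, z_dim=8)
ctx = torch.randn(4, 32)
mean, std = fusion(ctx, *[torch.randn(4, 8) for _ in range(4)])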
Example #13
File: model.py  Project: JindongJiang/GNM
    def render(self, pa: List, lv_z: List) -> List:
        """

        :param pa: variables with size (bs, dim, num_cell, num_cell)
        :param lv_z: o and a with size (bs * num_cell * num_cell, dim)
        :return:
        """

        [z_pres, z_where, z_depth, _, _] = lv_z
        [o_att, a_att, bg] = pa

        bs = z_pres.size(0)

        z_pres = z_pres.permute(0, 2, 3, 1).reshape(
            bs * self.args.arch.num_cell**2, -1)
        z_where = z_where.permute(0, 2, 3, 1).reshape(
            bs * self.args.arch.num_cell**2, -1)
        z_depth = z_depth.permute(0, 2, 3, 1).reshape(
            bs * self.args.arch.num_cell**2, -1)

        if self.args.arch.phase_overlap:
            if self.args.train.phase_bg_alpha_curriculum:
                if self.args.train.bg_alpha_curriculum_period[0] < self.args.train.global_step < \
                        self.args.train.bg_alpha_curriculum_period[1]:
                    z_pres = z_pres.clamp(max=0.99)
            a_att_hat = a_att * z_pres.view(-1, 1, 1, 1)
            y_att = a_att_hat * o_att

            # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 3, img_h, img_w)
            y_each_cell = spatial_transform(
                y_att,
                z_where,
                (bs * self.args.arch.num_cell**2, self.args.data.inp_channel,
                 self.args.data.img_h, self.args.data.img_w),
                inverse=True)

            # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, glimpse_size, glimpse_size)
            importance_map = a_att_hat * torch.sigmoid(-z_depth).view(
                -1, 1, 1, 1)
            # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1, img_h, img_w)
            importance_map_full_res = spatial_transform(
                importance_map,
                z_where,
                (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1,
                 self.args.data.img_h, self.args.data.img_w),
                inverse=True)
            # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w)
            importance_map_full_res = \
                importance_map_full_res.view(-1, self.args.arch.num_cell * self.args.arch.num_cell, 1,
                                             self.args.data.img_h,
                                             self.args.data.img_w)
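            # Normalize importance over cells so that overlapping objects
            # composite to a convex combination at every pixel.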
            importance_map_full_res_norm = importance_map_full_res / \
                                           (importance_map_full_res.sum(dim=1, keepdim=True) + self.args.const.eps)

            # (bs, 3, img_h, img_w)
            y_nobg = \
                (y_each_cell.view(-1, self.args.arch.num_cell * self.args.arch.num_cell,
                                  self.args.data.inp_channel, self.args.data.img_h,
                                  self.args.data.img_w) * importance_map_full_res_norm).sum(dim=1)

            # (bs, self.args.arch.num_cell * self.args.arch.num_cell, 1, img_h, img_w)
            alpha_map = spatial_transform(
                a_att_hat,
                z_where,
                (self.args.arch.num_cell * self.args.arch.num_cell * bs, 1,
                 self.args.data.img_h, self.args.data.img_w),
                inverse=True).view(
                    -1, self.args.arch.num_cell * self.args.arch.num_cell, 1,
                    self.args.data.img_h, self.args.data.img_w).sum(dim=1)
            # (bs, 1, img_h, img_w)
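            # Straight-through clamp: the forward value is clamped for
            # stability, but gradients flow through the raw alpha_map.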
            alpha_map = alpha_map + (
                alpha_map.clamp(self.args.const.eps, 1 - self.args.const.eps) -
                alpha_map).detach()

            if self.args.train.phase_bg_alpha_curriculum:
                if self.args.train.bg_alpha_curriculum_period[0] < self.args.train.global_step < \
                        self.args.train.bg_alpha_curriculum_period[1]:
                    alpha_map = alpha_map.new_ones(alpha_map.size()) * \
                        self.args.train.bg_alpha_curriculum_value
                    # y_nobg = alpha_map * y_nobg
            y = y_nobg + (1. - alpha_map) * bg
        else:
            y_att = a_att * o_att

            # (self.args.arch.num_cell * self.args.arch.num_cell * bs, 3, img_h, img_w)
            y_each_cell = spatial_transform(
                y_att,
                z_where,
                (bs * self.args.arch.num_cell**2, self.args.data.inp_channel,
                 self.args.data.img_h, self.args.data.img_w),
                inverse=True)
            y = (y_each_cell * z_pres.view(bs * self.args.arch.num_cell ** 2, 1, 1, 1)). \
                view(bs, -1, self.args.data.inp_channel, self.args.data.img_h, self.args.data.img_w).sum(dim=1)
            y_nobg = y
            alpha_map = y.new_ones(y.size(0), 1, y.size(2), y.size(3))

        pa = [y, y_nobg, alpha_map, bg]

        return pa
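
The alpha_map line in render above uses a straight-through clamp: the forward
pass sees values clamped to (eps, 1 - eps), while the backward pass treats the
op as the identity. A minimal sketch under that reading; the helper name
straight_through_clamp is illustrative.

import torch

def straight_through_clamp(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Forward: clamp(x, eps, 1 - eps). Backward: d/dx = 1, because the
    # detached residual contributes no gradient.
    return x + (x.clamp(eps, 1.0 - eps) - x).detach()

x = torch.tensor([-0.2, 0.5, 1.3], requires_grad=True)
y = straight_through_clamp(x)
y.sum().backward()
print(y.detach())  # tensor([1.0000e-06, 5.0000e-01, 1.0000e+00])
print(x.grad)      # tensor([1., 1., 1.])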
Example #14
def log_summary(args,
                writer,
                imgs,
                y_seq,
                global_step,
                log_disc_list,
                log_prop_list,
                scalor_log_list,
                prefix='train',
                eps=1e-15):
    args = copy(args)
    bs = imgs.size(0)
    seq_len = imgs.shape[1]
    grid_image = make_grid(imgs[:args.num_img_summary * 2].cpu().view(
        -1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/1-image', grid_image, global_step)

    grid_image = make_grid(y_seq[:args.num_img_summary * 2].cpu().view(
        -1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/2-reconstruction_overall', grid_image,
                     global_step)

    bbox_prop_list = []
    bbox_disc_list = []
    recon_prop_list = []
    recon_disc_list = []
    bg_list = []
    alpha_map_list = []
    x_mask_color_list = []
    # for each time step
    for j in range(imgs.size(1)):

        # per-object reconstructions, split below into propagation and
        # discovery parts
        y_each_obj = scalor_log_list[j]['y_each_obj'][:args.num_img_summary]
        importance_map_norm = scalor_log_list[j][
            'importance_map_norm'][:args.num_img_summary]

        y_prop_disc = y_each_obj * importance_map_norm

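        # The last num_cell_h * num_cell_w object slots hold newly discovered
        # objects; all earlier slots hold propagated ones.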
        recon_prop_list.append(y_prop_disc[:, :-args.num_cell_h *
                                           args.num_cell_w].sum(dim=1))
        recon_disc_list.append(y_prop_disc[:, -args.num_cell_h *
                                           args.num_cell_w:].sum(dim=1))
        bg_list.append(scalor_log_list[j]['bg'][:args.num_img_summary])

        alpha_map_list.append(
            scalor_log_list[j]['alpha_map'][:args.num_img_summary])
        x_mask_color_list.append(
            scalor_log_list[j]['x_mask_color'][:args.num_img_summary])

        if prefix == 'train' and not args.phase_simplify_summary:
            writer.add_histogram(
                f'{prefix}_inside_value_scalor_{j}/importance_map_norm',
                scalor_log_list[j]['importance_map_norm'][
                    scalor_log_list[j]['importance_map_norm'] > 0].cpu(
                    ).detach().numpy(), global_step)
            for k, v in scalor_log_list[j].items():
                if '_bg_' in k:
                    writer.add_histogram(
                        f'{prefix}_inside_value_scalor_{j}/{k}',
                        v.cpu().detach().numpy(), global_step)
            if args.phase_conv_lstm:
                for k, v in scalor_log_list[j].items():
                    if 'lstm' in k:
                        writer.add_histogram(
                            f'{prefix}_inside_value_scalor_{j}/{k}',
                            v.cpu().detach().numpy(), global_step)

        num_cell_h, num_cell_w = args.num_cell_h, args.num_cell_w

        log_disc = {
            'z_what': log_disc_list[j]['z_what'].view(-1, num_cell_h * num_cell_w, z_what_dim),
            'z_where_scale':
                log_disc_list[j]['z_where'].view(-1, num_cell_h * num_cell_w, z_where_scale_dim + z_where_shift_dim)[:, :,
                :z_where_scale_dim],
            'z_where_shift':
                log_disc_list[j]['z_where'].view(-1, num_cell_h * num_cell_w, z_where_scale_dim + z_where_shift_dim)[:, :,
                z_where_scale_dim:],
            'z_pres': log_disc_list[j]['z_pres'].permute(0, 2, 3, 1),
            'z_pres_probs': torch.sigmoid(log_disc_list[j]['z_pres_logits']).permute(0, 2, 3, 1),
            'z_what_std': log_disc_list[j]['z_what_std'].view(-1, num_cell_h * num_cell_w, z_what_dim),
            'z_what_mean': log_disc_list[j]['z_what_mean'].view(-1, num_cell_h * num_cell_w, z_what_dim),
            'z_where_scale_std':
                log_disc_list[j]['z_where_std'].permute(0, 2, 3, 1)[:, :, :z_where_scale_dim],
            'z_where_scale_mean':
                log_disc_list[j]['z_where_mean'].permute(0, 2, 3, 1)[:, :, :z_where_scale_dim],
            'z_where_shift_std':
                log_disc_list[j]['z_where_std'].permute(0, 2, 3, 1)[:, :, z_where_scale_dim:],
            'z_where_shift_mean':
                log_disc_list[j]['z_where_mean'].permute(0, 2, 3, 1)[:, :, z_where_scale_dim:],
            'glimpse': log_disc_list[j]['x_att'].view(-1, num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size) \
                if prefix != 'generate' else None,
            'glimpse_recon': log_disc_list[j]['y_att'].view(-1, num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size),
            'prior_z_pres_prob': log_disc_list[j]['prior_z_pres_prob'].unsqueeze(0),
            'o_each_cell': spatial_transform(log_disc_list[j]['o_att'], log_disc_list[j]['z_where'],
                                             (num_cell_h * num_cell_w * bs, 3, img_h, img_w),
                                             inverse=True).view(-1, num_cell_h * num_cell_w, 3, img_h, img_w),
            'alpha_hat_each_cell': spatial_transform(log_disc_list[j]['alpha_att_hat'],
                                                     log_disc_list[j]['z_where'],
                                                     (num_cell_h * num_cell_w * bs, 1, img_h, img_w),
                                                     inverse=True).view(-1, num_cell_h * num_cell_w, 1, img_h, img_w),
            'alpha_each_cell': spatial_transform(log_disc_list[j]['alpha_att'], log_disc_list[j]['z_where'],
                                                 (num_cell_h * num_cell_w * bs, 1, img_h, img_w),
                                                 inverse=True).view(-1, num_cell_h * num_cell_w, 1, img_h, img_w),
            'y_each_cell': (log_disc_list[j]['y_each_cell'] * log_disc_list[j]['z_pres'].
                            view(-1, 1, 1, 1)).view(-1, num_cell_h * num_cell_w, 3, img_h, img_w),
            'z_depth': log_disc_list[j]['z_depth'].view(-1, num_cell_h * num_cell_w, z_depth_dim),
            'z_depth_std': log_disc_list[j]['z_depth_std'].view(-1, num_cell_h * num_cell_w, z_depth_dim),
            'z_depth_mean': log_disc_list[j]['z_depth_mean'].view(-1, num_cell_h * num_cell_w, z_depth_dim),
            'z_pres_logits': log_disc_list[j]['z_pres_logits'].permute(0, 2, 3, 1),
            'z_pres_y': log_disc_list[j]['z_pres_y'].permute(0, 2, 3, 1)
        }

        bbox = visualize(
            imgs[:args.num_img_summary, j].cpu(),
            log_disc['z_pres'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_scale'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_shift'][:args.num_img_summary].cpu().detach())

        y_each_cell = log_disc['y_each_cell'].view(
            -1, 3, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()
        o_each_cell = log_disc['o_each_cell'].view(
            -1, 3, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()
        alpha_each_cell = log_disc['alpha_hat_each_cell'].view(
            -1, 1, img_h, img_w)[:args.num_img_summary * args.num_cell_h *
                                 args.num_cell_w].cpu().detach()

        if log_prop_list[j]:
            log_prop = {
                'z_what':
                log_prop_list[j]['z_what'].view(bs, -1, z_what_dim),
                'z_where_scale':
                log_prop_list[j]['z_where'].view(
                    bs, -1, z_where_scale_dim +
                    z_where_shift_dim)[:, :, :z_where_scale_dim],
                'z_where_shift':
                log_prop_list[j]['z_where'].view(
                    bs, -1,
                    z_where_scale_dim + z_where_shift_dim)[:, :,
                                                           z_where_scale_dim:],
                'z_pres':
                log_prop_list[j]['z_pres'],
                'z_what_std':
                log_prop_list[j]['z_what_std'].view(bs, -1, z_what_dim),
                'z_what_mean':
                log_prop_list[j]['z_what_mean'].view(bs, -1, z_what_dim),
                'z_where_bias_scale_std':
                log_prop_list[j]['z_where_bias_std'][:, :, :z_where_scale_dim],
                'z_where_bias_scale_mean':
                log_prop_list[j]['z_where_bias_mean']
                [:, :, :z_where_scale_dim],
                'z_where_bias_shift_std':
                log_prop_list[j]['z_where_bias_std'][:, :, z_where_scale_dim:],
                'z_where_bias_shift_mean':
                log_prop_list[j]['z_where_bias_mean'][:, :,
                                                      z_where_scale_dim:],
                'z_pres_probs':
                torch.sigmoid(log_prop_list[j]['z_pres_logits']),
                #'glimpse': log_prop_list[j]['glimpse'],
                #'glimpse_recon': log_prop_list[j]['glimpse_recon'],
                #'prior_z_pres_prob': log_prop_list[j]['prior_z_pres_prob'],
                #'prior_where_bias_scale_std':
                #    log_prop_list[j]['prior_where_bias_std'][:, :, :z_where_scale_dim],
                #'prior_where_bias_scale_mean':
                #    log_prop_list[j]['prior_where_bias_mean'][:, :, :z_where_scale_dim],
                #'prior_where_bias_shift_std':
                #    log_prop_list[j]['prior_where_bias_std'][:, :, z_where_scale_dim:],
                #'prior_where_bias_shift_mean':
                #    log_prop_list[j]['prior_where_bias_mean'][:, :, z_where_scale_dim:],
                'lengths':
                log_prop_list[j]['lengths'],
                'z_depth':
                log_prop_list[j]['z_depth'],
                'z_depth_std':
                log_prop_list[j]['z_depth_std'],
                'z_depth_mean':
                log_prop_list[j]['z_depth_mean'],
                'y_each_obj':
                log_prop_list[j]['y_each_obj'],
                'alpha_hat_each_obj':
                log_prop_list[j]['alpha_map'],
                'z_pres_logits':
                log_prop_list[j]['z_pres_logits'],
                'z_pres_y':
                log_prop_list[j]['z_pres_y'],
                'o_each_obj':
                spatial_transform(
                    log_prop_list[j]['o_att'].view(-1, 3, glimpse_size,
                                                   glimpse_size),
                    log_prop_list[j]['z_where'].view(
                        -1, (z_where_scale_dim + z_where_shift_dim)),
                    (log_prop_list[j]['o_att'].size(1) * bs, 3, img_h, img_w),
                    inverse=True).view(bs, -1, 3, img_h, img_w),
                'z_where_bias_scale':
                log_prop_list[j]['z_where_bias'].view(
                    bs, -1, z_where_scale_dim +
                    z_where_shift_dim)[:, :, :z_where_scale_dim],
                'z_where_bias_shift':
                log_prop_list[j]['z_where_bias'].view(
                    bs, -1,
                    z_where_scale_dim + z_where_shift_dim)[:, :,
                                                           z_where_scale_dim:],
            }

            num_obj = log_prop['z_pres'].size(1)
            idx = [[], []]
            for k in range(bs):
                for l in range(int(log_prop['lengths'][k])):
                    idx[0].append(k)
                    idx[1].append(l)
            idx_false = [[], []]
            for k in range(bs):
                for l in range(num_obj - int(log_prop['lengths'][k])):
                    idx_false[0].append(k)
                    idx_false[1].append(int(log_prop['lengths'][k] + l))
            if prefix == 'train' and not args.phase_simplify_summary:
                for key, value in log_prop.items():
                    if key == 'lengths':
                        writer.add_histogram(
                            f'{prefix}_inside_value_prop_{j}/{key}',
                            value.cpu().detach().numpy(), global_step)
                    else:
                        writer.add_histogram(
                            f'{prefix}_inside_value_prop_{j}/{key}',
                            value.cpu().detach()[idx].numpy(), global_step)

            bbox_prop = visualize(
                imgs[:args.num_img_summary, j].cpu(),
                log_prop['z_pres'][:args.num_img_summary].cpu().detach(),
                log_prop['z_where_scale']
                [:args.num_img_summary].cpu().detach(),
                log_prop['z_where_shift']
                [:args.num_img_summary].cpu().detach(),
                only_bbox=True)

            bbox_prop = bbox_prop.view(args.num_img_summary, -1, 3, img_h,
                                       img_w)
            bbox_prop_one_time_step = (
                bbox_prop.sum(dim=1) +
                imgs[:args.num_img_summary, j].cpu()).clamp(0, 1)
            bbox_prop_list.append(bbox_prop_one_time_step)

            node_type = log_prop_list[j]['node_type']
            edge_type = log_prop_list[j]['edge_type'].reshape(
                log_prop_list[j]['node_type'].shape[0],
                log_prop_list[j]['node_type'].shape[1],
                log_prop_list[j]['node_type'].shape[1], 2)

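            # red, blue and green are assumed to be module-level color tuples
            # for cv2; "* 32 + 32" maps shifts in [-1, 1] onto 64-px images.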
            back_imgs = bg_list[-1]
            for idx_bg in range(back_imgs.shape[0]):
                img_cv2_format = np.uint8(
                    255 * (back_imgs[idx_bg].cpu().numpy()).transpose(
                        1, 2, 0)).copy()
                pres_idxs = [
                    i for i, e in enumerate(log_prop['z_pres'][idx_bg, :, 0])
                    if e == 1
                ]
                for idx_node in pres_idxs:
                    x_c = log_prop['z_where_shift'][idx_bg, idx_node,
                                                    0].cpu().numpy() * 32 + 32
                    y_c = log_prop['z_where_shift'][idx_bg, idx_node,
                                                    1].cpu().numpy() * 32 + 32
                    center_coordinates = (int(x_c), int(y_c))

                    if node_type[idx_bg, idx_node] > 0.5:
                        color = red
                    else:
                        color = blue

                    for idx_node_2 in pres_idxs:
                        if idx_node != idx_node_2:
                            x_c_2 = int(log_prop['z_where_shift'][
                                idx_bg, idx_node_2, 0].cpu().numpy() * 32 + 32)
                            y_c_2 = int(log_prop['z_where_shift'][
                                idx_bg, idx_node_2, 1].cpu().numpy() * 32 + 32)

                            if edge_type[idx_bg, idx_node, idx_node_2,
                                         1] > 0.5:
                                img_cv2_format = cv2.line(img_cv2_format,
                                                          center_coordinates,
                                                          (x_c_2, y_c_2),
                                                          green,
                                                          1,
                                                          lineType=cv2.LINE_AA)

                    img_cv2_format = cv2.circle(img_cv2_format,
                                                center_coordinates,
                                                3,
                                                color,
                                                -1,
                                                lineType=cv2.LINE_AA)

                back_imgs[idx_bg] = torch.tensor(
                    img_cv2_format.transpose(2, 0, 1) / 255)

        else:
            bbox_prop_one_time_step = imgs[:args.num_img_summary, j].cpu()
            bbox_prop_list.append(bbox_prop_one_time_step)
        if prefix == 'train' and not args.phase_simplify_summary:
            for key, value in log_disc.items():
                writer.add_histogram(f'{prefix}_inside_value_disc_{j}/{key}',
                                     value.cpu().detach().numpy(), global_step)

        if not args.phase_simplify_summary:
            for m in range(int(min(args.num_img_summary, bs))):

                grid_image = make_grid(
                    bbox[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                         args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/1-bbox_{m}_{j}', grid_image,
                                 global_step)

                grid_image = make_grid(
                    y_each_cell[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                                args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/2-y_each_cell_{m}_{j}',
                                 grid_image, global_step)

                grid_image = make_grid(
                    o_each_cell[m * args.num_cell_h * args.num_cell_w:(m + 1) *
                                args.num_cell_h * args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(f'{prefix}_disc/3-o_each_cell_{m}_{j}',
                                 grid_image, global_step)

                grid_image = make_grid(
                    alpha_each_cell[m * args.num_cell_h *
                                    args.num_cell_w:(m + 1) * args.num_cell_h *
                                    args.num_cell_w],
                    8,
                    normalize=True,
                    pad_value=1)
                writer.add_image(
                    f'{prefix}_disc/4-alpha_hat_each_cell_{m}_{j}', grid_image,
                    global_step)

                if log_prop_list[j]:
                    bbox_prop = visualize(
                        imgs[m, j].cpu(), log_prop['z_pres'][m].cpu().detach(),
                        log_prop['z_where_scale'][m].cpu().detach(),
                        log_prop['z_where_shift'][m].cpu().detach())

                    grid_image = make_grid(bbox_prop,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/1-bbox_{m}_{j}',
                                     grid_image, global_step)

                    y_each_obj = log_prop['y_each_obj'][m].view(
                        -1, 3, img_h, img_w).cpu().detach()
                    grid_image = make_grid(y_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/2-y_each_obj_{m}_{j}',
                                     grid_image, global_step)

                    o_each_obj = log_prop['o_each_obj'][m].view(
                        -1, 3, img_h, img_w).cpu().detach()
                    grid_image = make_grid(o_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/3-o_each_obj_{m}_{j}',
                                     grid_image, global_step)

                    alpha_each_obj = log_prop['alpha_hat_each_obj'][m].view(
                        -1, 1, img_h, img_w).cpu().detach()
                    grid_image = make_grid(alpha_each_obj,
                                           5,
                                           normalize=True,
                                           pad_value=1)
                    writer.add_image(f'{prefix}_prop/4-alpha_each_obj_{m}_{j}',
                                     grid_image, global_step)

        bbox_disc = visualize(
            imgs[:args.num_img_summary, j].cpu(),
            log_disc['z_pres'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_scale'][:args.num_img_summary].cpu().detach(),
            log_disc['z_where_shift'][:args.num_img_summary].cpu().detach(),
            only_bbox=True)
        bbox_disc = bbox_disc.view(args.num_img_summary, -1, 3, img_h, img_w)
        bbox_disc = (bbox_disc.sum(dim=1) +
                     imgs[:args.num_img_summary, j].cpu()).clamp(0, 1)
        bbox_disc_list.append(bbox_disc)

    recon_disc = torch.stack(recon_disc_list, dim=1)
    grid_image = make_grid(recon_disc.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/3-reconstruction_disc', grid_image,
                     global_step)

    recon_prop = torch.stack(recon_prop_list, dim=1)
    grid_image = make_grid(recon_prop.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/4-reconstruction_prop', grid_image,
                     global_step)

    bbox_disc_all = torch.stack(bbox_disc_list, dim=1)
    grid_image = make_grid(bbox_disc_all.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/5-bbox_disc', grid_image, global_step)

    bbox_prop_all = torch.stack(bbox_prop_list, dim=1)
    grid_image = make_grid(bbox_prop_all.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/6-bbox_prop', grid_image, global_step)

    bg = torch.stack(bg_list, dim=1)
    grid_image = make_grid(bg.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=True,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/7-background', grid_image, global_step)

    alpha_map = torch.stack(alpha_map_list, dim=1)
    grid_image = make_grid(alpha_map.view(-1, 1, img_h, img_w),
                           seq_len,
                           normalize=False,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/8-alpha-map', grid_image, global_step)

    x_mask_color = torch.stack(x_mask_color_list, dim=1)
    grid_image = make_grid(x_mask_color.view(-1, 3, img_h, img_w),
                           seq_len,
                           normalize=False,
                           pad_value=1)
    writer.add_image(f'{prefix}_scalor/9-x-mask-color', grid_image,
                     global_step)

    return
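
Every summary in log_summary follows the same make_grid / add_image pattern.
For reference, that pattern in isolation; the log directory and tensor shapes
are illustrative.

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

writer = SummaryWriter('runs/example')  # illustrative log dir
frames = torch.rand(4, 10, 3, 64, 64)   # (batch, seq_len, C, H, W)
grid = make_grid(frames.view(-1, 3, 64, 64),
                 nrow=frames.size(1),    # one row per sequence
                 normalize=True,
                 pad_value=1)
writer.add_image('demo/sequence', grid, global_step=0)
writer.close()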