Example #1
    def forward(self, objs, layout_boxes, layout_masks, test_mode=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]
            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(objs_vecs_batch,
                                      layout_boxes_batch,
                                      layout_masks_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0],
                                      test_mode=test_mode)
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)
        seg = torch.cat(seg_batches, dim=0)

        # downsample the segmentation map and run the first convolution
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)
        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.opt.num_upsampling_layers == 'more' or \
                self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)  # F.tanh is deprecated; use torch.tanh

        return x
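
For orientation: boxes_to_layout and masks_to_layout are defined elsewhere in the repository. The sketch below is only an assumption about what the box variant returns for one batch element (the real implementation typically uses bilinear sampling rather than this nearest-pixel fill); it is meant to illustrate the output shape that gets concatenated into seg.

import torch

def boxes_to_layout_sketch(obj_vecs, boxes, H, W):
    # obj_vecs: (N, D) object embeddings; boxes: (N, 4) as (x0, y0, x1, y1) in [0, 1]
    N, D = obj_vecs.shape
    layout = torch.zeros(1, D, H, W, dtype=obj_vecs.dtype, device=obj_vecs.device)
    for i in range(N):
        x0, y0, x1, y1 = boxes[i].tolist()
        xs, xe = int(x0 * W), max(int(x1 * W), int(x0 * W) + 1)
        ys, ye = int(y0 * H), max(int(y1 * H), int(y0 * H) + 1)
        # splat the object's embedding into its box region
        layout[0, :, ys:ye, xs:xe] += obj_vecs[i].view(D, 1, 1)
    return layout  # (1, D, H, W); per-sample results are concatenated along dim 0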
Example #2
    def forward(self,
                img,
                objs,
                layout_boxes,
                layout_masks=None,
                gt_train=True,
                fool=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']

        # Compose the layout for each sample in the batch
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]

            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(
                    objs_vecs_batch,
                    layout_boxes_batch,
                    layout_masks_batch,
                    self.opt.image_size[0],
                    self.opt.image_size[0],
                    test_mode=False)  # test mode always false in disc.
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)

        # layout = torch.cat(layout_batches, dim=0)  # [B, N, d']
        seg = torch.cat(seg_batches, dim=0)
        input = torch.cat([img, seg], dim=1)

        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for name, D in self.named_children():
            if name.startswith('discriminator'):
                out = D(input)
                if not get_intermediate_features:
                    out = [out]
                result.append(out)
                input = self.downsample(input)
        return result
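
The loop over self.named_children() assumes the discriminators were registered as child modules whose names start with 'discriminator', with a shared downsampling op applied to the (img, seg) input between scales. A hypothetical sketch of such a layout (module names and layer sizes are illustrative, not the repository's actual classes):

import torch.nn as nn

class MultiScaleDiscriminatorSketch(nn.Module):
    def __init__(self, in_channels, num_scales=2):
        super().__init__()
        for i in range(num_scales):
            # each scale sees a progressively downsampled (img, seg) input
            self.add_module(
                'discriminator_%d' % i,
                nn.Sequential(nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
                              nn.LeakyReLU(0.2),
                              nn.Conv2d(64, 1, 4, stride=1, padding=1)))
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1,
                                       count_include_pad=False)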
Example #3
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        # 'objs' => indices into model.vocab['object_idx_to_name']
        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        # 'p' => indices into model.vocab['pred_idx_to_name']
        pred_vecs = self.pred_embeddings(p)

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        # bounding box prediction
        boxes_pred_info = None
        if self.use_bbox_info:
            # bounding box prediction + predicted box info
            boxes_pred_info = self.box_net(obj_vecs)
            boxes_pred = boxes_pred_info[:, 0:4]  # first 4 entries are bbox coords
        else:
            boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        layout_masks = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # this only affects training if loss is non-zero
        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        # uses predicted subject/object boxes and the refined (GCNN output)
        # subject/object embeddings; the commented line below would use the
        # original embeddings (the GCNN input) instead
        ##rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs_pred, o_vecs_pred],
                                  dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        # concatenate triplet vectors
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

        # triplet bounding boxes
        triplet_boxes_pred = None
        if self.triplet_box_net is not None:
            # predict 8 point bounding boxes
            triplet_boxes_pred = self.triplet_box_net(triplet_input)

        # triplet binary masks
        triplet_masks_pred = None
        if self.triplet_mask_net is not None:
            # input dimension must be [h, w, 1, 1]
            triplet_mask_scores = self.triplet_mask_net(
                triplet_input[:, :, None, None])
            # only used for binary/masks CE loss
            #triplet_masks_pred = triplet_mask_scores.squeeze(1).sigmoid()
            triplet_masks_pred = triplet_mask_scores.squeeze(1)

        # triplet embedding
        triplet_embed = None
        if self.triplet_embed_net is not None:
            triplet_embed = self.triplet_embed_net(triplet_input)

        # triplet superbox
        triplet_superboxes_pred = None
        if self.triplet_superbox_net is not None:
            # predict 8 point bounding boxes
            triplet_superboxes_pred = self.triplet_superbox_net(
                triplet_input)  # s/p/o (bboxes?)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # compose layout mask
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        layout_crn = layout
        sg_context_pred = None
        sg_context_pred_d = None
        if self.sg_context_net is not None:
            N, C, H, W = layout.size()
            context = sg_context_to_layout(obj_vecs,
                                           obj_to_img,
                                           pooling=self.gcnn_pooling)
            sg_context_pred_sqz = self.sg_context_net(context)

            #### vector to spatial replication
            b = N
            s = self.sg_context_dim
            # b, s = sg_context_pred_sqz.size()
            sg_context_pred = sg_context_pred_sqz.view(b, s, 1, 1).expand(
                b, s, layout.size(2), layout.size(3))
            layout_crn = torch.cat([layout, sg_context_pred], dim=1)

            ## discriminator uses different FC layer than the generator
            sg_context_predd_sqz = self.sg_context_net_d(context)
            s = self.sg_context_dim_d
            sg_context_pred_d = sg_context_predd_sqz.view(b, s, 1, 1).expand(
                b, s, layout.size(2), layout.size(3))

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

        # layout model only
        #img = self.refinement_net(layout_crn)
        img = None

        # compose triplet boxes using 'triplets', objs, etc.
        if boxes_gt is not None:
            s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
            triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
        else:
            triplet_boxes_gt = None

        #return img, boxes_pred, masks_pred, rel_scores
        return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
                layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
                rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
                triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
                triplet_superboxes_pred)
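
A toy construction of the required inputs described in the docstring, meant only to show the expected shapes (the category and predicate indices are made up):

import torch

objs = torch.tensor([3, 17, 5])                 # (O,) object category indices
triples = torch.tensor([[0, 2, 1],              # (T, 3) rows of [s, p, o]
                        [1, 4, 2]])
obj_to_img = torch.zeros(3, dtype=torch.long)   # all three objects in image 0

s, p, o = [x.squeeze(1) for x in triples.chunk(3, dim=1)]
edges = torch.stack([s, o], dim=1)              # (T, 2) graph edges
assert edges.shape == (2, 2)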
Example #4
    def forward(self,
                obj_to_img,
                boxes_gt,
                obj_fmaps,
                mask_noise_indexes=None,
                masks_gt=None,
                bg_layout=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        assert boxes_gt.max() < 1.1 and boxes_gt.min() > -0.1, \
            "boxes_gt should be within range [0,1]"
        # O, T = objs.size(0), triples.size(0)
        # s, p, o = triples.chunk(3, dim=1)           # All have shape (T, 1)
        # s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,)
        # edges = torch.stack([s, o], dim=1)          # Shape is (T, 2)
        #
        # if obj_to_img is None:
        #   obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)
        #
        # obj_vecs = self.obj_embeddings(objs)
        # obj_vecs_orig = obj_vecs
        # pred_vecs = self.pred_embeddings(p)
        #
        # if isinstance(self.gconv, nn.Linear):
        #   obj_vecs = self.gconv(obj_vecs)
        # else:
        #   obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        # if self.gconv_net is not None:
        #   obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)
        if self.args.not_decrease_feature_dimension:
            obj_vecs = obj_fmaps
        else:
            obj_vecs = self.obj_fmap_net(obj_fmaps)
        no_noise_obj_vecs = obj_vecs
        if self.args.object_noise_dim > 0:
            # select objects belonging to images in mask_noise_indexes
            if mask_noise_indexes is not None and self.training:
                mask_noise_obj_index_list = []
                for ind in mask_noise_indexes:
                    mask_noise_obj_index_list.append(
                        (obj_to_img == ind).nonzero())
                mask_noise_obj_indexes = torch.cat(mask_noise_obj_index_list,
                                                   dim=0)[:, 0]

            if self.args.noise_apply_method == "concat":
                object_noise = torch.randn(
                    (obj_vecs.shape[0], self.args.object_noise_dim),
                    dtype=obj_vecs.dtype,
                    device=obj_vecs.device)
                if mask_noise_indexes is not None and self.training:
                    object_noise[mask_noise_obj_indexes] = 0
                obj_vecs = torch.cat([obj_vecs, object_noise], dim=1)
            elif self.args.noise_apply_method == "add":
                object_noise = torch.randn(obj_vecs.shape,
                                           dtype=obj_vecs.dtype,
                                           device=obj_vecs.device)
                if mask_noise_indexes is not None and self.training:
                    object_noise[mask_noise_obj_indexes] = 0
                obj_vecs = obj_vecs + object_noise

        # boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(
                obj_vecs.view(obj_vecs.shape[0], -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        # s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        # rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        # rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        # layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
        layout_boxes = boxes_gt

        # layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)

        if masks_pred is None:
            if self.args.object_no_noise_with_bbox:
                layout = boxes_to_layout(no_noise_obj_vecs, layout_boxes,
                                         obj_to_img, H, W)
            else:
                layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H,
                                         W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            if self.args.object_no_noise_with_mask:
                layout = masks_to_layout(no_noise_obj_vecs, layout_boxes,
                                         layout_masks, obj_to_img, H, W)
            else:
                layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                         obj_to_img, H, W)
        ret_layout = layout

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            if self.args.noise_apply_method == "concat":
                noise_shape = (N, self.layout_noise_dim, H, W)
            elif self.args.noise_apply_method == "add":
                noise_shape = layout.shape
            # print("check noise_std here, it is %.10f" % self.args.noise_std)
            noise_std = torch.zeros(noise_shape,
                                    dtype=layout.dtype,
                                    device=layout.device).fill_(
                                        self.args.noise_std)
            layout_noise = torch.normal(mean=0.0, std=noise_std)
            if self.args.layout_noise_only_on_foreground:
                layout_noise *= (1 - bg_layout[:, :1, :, :].repeat(
                    1, self.layout_noise_dim, 1, 1))

            if mask_noise_indexes is not None and self.training:
                layout_noise[mask_noise_indexes] = 0.
            # layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
            #                            device=layout.device)
            if self.args.noise_apply_method == "concat":
                layout = torch.cat([layout, layout_noise], dim=1)
            elif self.args.noise_apply_method == "add":
                layout = layout + layout_noise
        img = self.refinement_net(layout)
        return img, ret_layout
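
A standalone toy run of the two args.noise_apply_method branches handled above, just to show the resulting shapes (the dimensions are made up):

import torch

obj_vecs = torch.randn(4, 128)                  # 4 objects, 128-d embeddings
object_noise_dim = 32

# "concat": append a noise vector to every object embedding
concat_vecs = torch.cat([obj_vecs, torch.randn(4, object_noise_dim)], dim=1)
assert concat_vecs.shape == (4, 160)

# "add": perturb each embedding with noise of the same shape
add_vecs = obj_vecs + torch.randn_like(obj_vecs)
assert add_vecs.shape == (4, 128)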
Example #5
    def forward(self, objs, triples, obj_to_img=None, pred_to_img=None,
                boxes_gt=None, masks_gt=None):
        """
        Required Inputs:
        - objs: LongTensor of shape (O,) giving categories for all objects
        - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
          means that there is a triple (objs[s], p, objs[o])

        Optional Inputs:
        - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
          means that objects[o] is an object in image i. If not given then
          all objects are assumed to belong to the same image.
        - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
          the spatial layout; if not given then use predicted boxes.
        """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)           # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)          # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs, pred_vecs = self.embedding(objs, p)
        obj_vecs_orig = obj_vecs

        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        boxes_pred = self.box_net(obj_vecs)
        
        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                   obj_to_img, H, W)

        # Add context embedding
#         context = self.context_network(pred_vecs)
        # TODO how to concatenate this?

#         if self.layout_noise_dim > 0:
#             N, C, H, W = layout.size()
#             # Concatenate noise with new context embedding and make proper shape
#             noise = torch.randn(N, self.layout_noise_dim)
#             noise = noise.view(noise.size(0), self.layout_noise_dim)
#             z = torch.cat([noise,proj_c],1)
#             layout_noise = self.noise_layout(z)
#             layout = torch.cat([layout, layout_noise], dim=1)
        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                     device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)

        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores
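
A quick shape check of the layout-noise concatenation above (channel counts are illustrative):

import torch

layout = torch.zeros(2, 128, 64, 64)            # (N, C, H, W) composed layout
layout_noise = torch.randn(2, 32, 64, 64)       # layout_noise_dim = 32
layout = torch.cat([layout, layout_noise], dim=1)
assert layout.shape == (2, 160, 64, 64)         # refinement net sees C + noise channels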
Example #6
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None,
                tr_to_img=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        # 'objs' => indices into model.vocab['object_idx_to_name']
        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        # 'p' => indices into model.vocab['pred_idx_to_name']
        pred_vecs = self.pred_embeddings(p)
        pred_vecs_orig = pred_vecs

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        #### object context vectors ###############
        num_imgs = obj_to_img[obj_to_img.size(0) - 1] + 1  # assumes obj_to_img is sorted by image
        context_obj_vecs = torch.zeros(num_imgs,
                                       obj_vecs.size(1),
                                       dtype=obj_vecs.dtype,
                                       device=obj_vecs.device)
        obj_to_img_exp = obj_to_img.view(-1, 1).expand_as(obj_vecs)
        context_obj_vecs = context_obj_vecs.scatter_add(
            0, obj_to_img_exp, obj_vecs)

        # get object counts
        obj_counts = torch.zeros(num_imgs,
                                 dtype=obj_vecs.dtype,
                                 device=obj_vecs.device)
        ones = torch.ones(obj_to_img.size(0),
                          dtype=obj_vecs.dtype,
                          device=obj_vecs.device)
        obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
        context_obj_vecs = context_obj_vecs / obj_counts.view(-1, 1)
        context_obj_vecs = context_obj_vecs[obj_to_img]
        context_obj_vecs = context_obj_vecs[s]
        ####################################

        ####### triplet context vectors ###########
        #context_tr_vecs = None
        # concatenate triplet vectors
        #triplets = torch.cat([obj_vecs[s], pred_vecs, obj_vecs[o]], dim=1)
        #context_tr_vecs = torch.zeros(num_imgs, triplets.size(1), dtype=obj_vecs.dtype, device=obj_vecs.device)
        # need triplet to image
        #tr_to_img_exp = tr_to_img.view(-1, 1).expand_as(triplets)
        #context_tr_vecs = context_tr_vecs.scatter_add(0, tr_to_img_exp, triplets)
        # get triplet counts
        #tr_counts = torch.zeros(num_imgs, dtype=obj_vecs.dtype, device=obj_vecs.device)
        #ones = torch.ones(triplets.size(0), dtype=obj_vecs.dtype, device=obj_vecs.device)
        #tr_counts = tr_counts.scatter_add(0, tr_to_img, ones)
        #context_tr_vecs = context_tr_vecs/tr_counts.view(-1,1)
        # dimension is (# triplets, 3*input_dim)
        #context_tr_vecs = context_tr_vecs[tr_to_img]
        # get some context!
        #context_tr_vecs = self.triplet_context_net(context_tr_vecs)
        ###########################################

        ####  mask out some predicates #####
        pred_mask_gt = None
        pred_mask_scores = None
        if self.use_masked_sg:
            perc = torch.FloatTensor([0.50])  # hyperparameter
            num_mask_objs = torch.floor(perc *
                                        len(s)).cpu().numpy()[0].astype(int)
            if num_mask_objs < 1:
                num_mask_objs = 1
            mask_idx = torch.randint(0, len(s) - 1, (num_mask_objs, ))
            #rand_idx = torch.randperm(len(s)-1)
            #mask_idx = rand_idx[:num_mask_objs]
            # GT
            pred_mask_gt = p[mask_idx.long()]  # return
            # set mask idx to masked embedding (e.g. new SG!)
            # clone so the masking below does not modify pred_vecs_orig in place
            pred_vecs_copy = pred_vecs_orig.clone()
            ##### need to add i=46 None embedding
            pred_vecs_copy[mask_idx.long()] = self.pred_embeddings(
                torch.tensor([self.mask_pred]).cuda())

            # convolve new masked SG
            if isinstance(self.gconv, nn.Linear):
                mask_obj_vecs = self.gconv(obj_vecs_orig)
            else:
                mask_obj_vecs, mask_pred_vecs = self.gconv(
                    obj_vecs_orig, pred_vecs_copy, edges)
            if self.gconv_net is not None:
                mask_obj_vecs, mask_pred_vecs = self.gconv_net(
                    mask_obj_vecs, mask_pred_vecs, edges)

            # subj/obj obj idx
            s_mask = s[mask_idx.long()]
            o_mask = o[mask_idx.long()]

            subj_vecs_mask = mask_obj_vecs[s_mask]
            obj_vecs_mask = mask_obj_vecs[o_mask]

            # predict masked predicate relationship
            pred_mask_input = torch.cat([subj_vecs_mask, obj_vecs_mask], dim=1)
            pred_mask_scores = self.pred_mask_net(pred_mask_input)
        #####################

        # bounding box prediction
        boxes_pred_info = None
        if self.use_bbox_info:
            # bounding box prediction + predicted box info
            boxes_pred_info = self.box_net(obj_vecs)
            boxes_pred = boxes_pred_info[:, 0:4]  # first 4 entries are bbox coords
        else:
            boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        layout_masks = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # predicted bboxes and embedding vectors
        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        # input embedding vectors
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        p_vecs = pred_vecs_orig
        input_tr_vecs = torch.cat([s_vecs, p_vecs, o_vecs], dim=1)

        # VSA (with obj/pred vectors of varying kinds)
        fr_obj_vecs = self.fr_obj_embeddings(objs)
        fr_pred_vecs = self.fr_pred_embeddings(p)
        fr_s_vecs, fr_o_vecs = fr_obj_vecs[s], fr_obj_vecs[o]
        mapc_bind = fr_s_vecs * fr_o_vecs * fr_pred_vecs
        # mapc_bind = s_vecs * o_vecs * p_vecs
        #mapc_bind = s_vecs_pred * o_vecs_pred * pred_vecs
        mapc_bind = F.normalize(mapc_bind, p=2, dim=1)

        # uses predicted subject/object boxes, original subject/object embedding (input to GCNN)
        ## use original embedding vectors
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        # subject prediction
        subj_aux_input = torch.cat([s_boxes, o_boxes, p_vecs, o_vecs], dim=1)
        subj_scores = self.subj_aux_net(subj_aux_input)

        # object prediction
        obj_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, p_vecs], dim=1)
        obj_scores = self.obj_aux_net(obj_aux_input)

        # object class prediction (for output object vectors)
        obj_class_scores = self.obj_class_aux_net(obj_vecs)

        # relationship class prediction (for output object vectors)
        # relationship embedding (very small embedding)
        # augment relationship embedding
        use_augmentation = False
        mask_rel_embedding = None
        if use_augmentation:
            num_augs = 4
            num_preds = len(p)
            mask_rel_embedding = torch.zeros(
                (num_preds, num_augs + 1, self.embedding_dim),
                dtype=obj_vecs.dtype,
                device=obj_vecs.device)
            rel_embedding = []
            #pred_vecs_mask = np.zeros((num_preds,num_augs,self.embedding_dim))
            # mask embedding perc_aug of vectors with 0
            perc_aug = torch.FloatTensor([0.4])  # hyperparameter
            num_mask = torch.floor(
                perc_aug * len(pred_vecs[0])).cpu().numpy()[0].astype(int)
            #p_ids = []
            for i in range(num_preds):
                pred = pred_vecs[i]
                #p_id = p[i]
                vecs = [pred]
                #p_ids += [p_id]
                for j in range(num_augs):
                    # pick a random set of indices to zero-out
                    rand_idx = torch.randperm(
                        len(pred_vecs[0]))  # 0-127 range shuffled
                    mask_idx = rand_idx[:num_mask]
                    pred_mask = pred.detach().clone()
                    pred_mask[mask_idx] = 0.0
                    vecs += [pred_mask]
                # project masked augmented relationship vectors
                pred_mask = self.rel_embed_aux_net(
                    torch.stack(vecs))  # output predicate embeddings
                pred_mask = F.normalize(pred_mask, dim=1)
                mask_rel_embedding[i, :, :] = pred_mask
                rel_embedding += [pred_mask[0]]

        #rel_embedding = torch.stack(rel_embedding)
        # projection head for supervised contrastive loss
        rel_embedding = self.rel_embed_aux_net(
            pred_vecs)  # output projected predicate embeddings
        rel_embedding = F.normalize(rel_embedding, dim=1)

        # relationship class prediction on predicates
        rel_class_scores = self.rel_class_aux_net(pred_vecs)

        # concatenate triplet vectors
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

        # triplet bounding boxes
        triplet_boxes_pred = None
        if self.triplet_box_net is not None:
            # predict 8 point bounding boxes
            triplet_boxes_pred = self.triplet_box_net(triplet_input)

        # triplet binary masks
        triplet_masks_pred = None
        if self.triplet_mask_net is not None:
            # input dimension must be [h, w, 1, 1]
            triplet_mask_scores = self.triplet_mask_net(
                triplet_input[:, :, None, None])
            # only used for binary/masks CE loss
            #triplet_masks_pred = triplet_mask_scores.squeeze(1).sigmoid()
            triplet_masks_pred = triplet_mask_scores.squeeze(1)

        # triplet embedding
        triplet_embed = None
        if self.triplet_embed_net is not None:
            triplet_embed = self.triplet_embed_net(triplet_input)

        # triplet superbox
        triplet_superboxes_pred = None
        if self.triplet_superbox_net is not None:
            # predict 2 point superboxes
            triplet_superboxes_pred = self.triplet_superbox_net(
                triplet_input)  # s/p/o (bboxes?)

        # predicate grounding
        pred_ground = None
        if self.pred_ground_net is not None:
            # predict 2 point pred grounding
            pred_ground = self.pred_ground_net(pred_vecs)  # s/p/o (bboxes?)

        # triplet context
        triplet_context_input = torch.cat(
            [context_obj_vecs, s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)
        # output dimension is 384
        context_tr_vecs = self.triplet_context_net(triplet_context_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # compose layout mask
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)
        layout_crn = layout
        sg_context_pred = None
        sg_context_pred_d = None
        if self.sg_context_net is not None:
            N, C, H, W = layout.size()
            context = sg_context_to_layout(obj_vecs,
                                           obj_to_img,
                                           pooling=self.gcnn_pooling)
            sg_context_pred_sqz = self.sg_context_net(context)

            #### vector to spatial replication
            b = N
            s = self.sg_context_dim
            # b, s = sg_context_pred_sqz.size()
            sg_context_pred = sg_context_pred_sqz.view(b, s, 1, 1).expand(
                b, s, layout.size(2), layout.size(3))
            layout_crn = torch.cat([layout, sg_context_pred], dim=1)

            ## discriminator uses different FC layer than the generator
            sg_context_predd_sqz = self.sg_context_net_d(context)
            s = self.sg_context_dim_d
            sg_context_pred_d = sg_context_predd_sqz.view(b, s, 1, 1).expand(
                b, s, layout.size(2), layout.size(3))

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

        # layout model only
        #img = self.refinement_net(layout_crn)
        img = None

        # compose triplet boxes using 'triplets', objs, etc.
        if boxes_gt is not None:
            s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
            triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
        else:
            triplet_boxes_gt = None

        #return img, boxes_pred, masks_pred, rel_scores
        return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
                layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
                rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
                triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
                triplet_superboxes_pred, obj_scores, pred_mask_gt,
                pred_mask_scores, context_tr_vecs, input_tr_vecs,
                obj_class_scores, rel_class_scores, subj_scores, rel_embedding,
                mask_rel_embedding, pred_ground)  #, mapc_bind
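
The per-image object context above is a mean computed with scatter_add; the standalone sketch below reproduces the same pooling on made-up values:

import torch

obj_vecs = torch.randn(5, 8)                    # 5 objects with 8-d vectors
obj_to_img = torch.tensor([0, 0, 1, 1, 1])      # first two objects belong to image 0
num_imgs = int(obj_to_img[-1]) + 1              # assumes obj_to_img is sorted

index = obj_to_img.view(-1, 1).expand_as(obj_vecs)
sums = torch.zeros(num_imgs, 8).scatter_add(0, index, obj_vecs)
counts = torch.zeros(num_imgs).scatter_add(0, obj_to_img, torch.ones(5))
context = sums / counts.view(-1, 1)             # per-image mean object vector
per_obj_context = context[obj_to_img]           # broadcast back to every object
assert per_obj_context.shape == obj_vecs.shape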
Example #7
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None):
        """ 
        Required Inputs:
        - objs: LongTensor of shape (O,) giving categories for all objects
        - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
          means that there is a triple (objs[s], p, objs[o])
        
        Main process: scene graph >>> graph convolution >>> scene layout >>> CRN (cascaded refinement network).

        Optional Inputs:
        - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
          means that objects[o] is an object in image i. If not given then
          all objects are assumed to belong to the same image.
        - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
          the spatial layout; if not given then use predicted boxes.
        """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        # edges specify start and end.
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(p)

        # Graph convolutional network.
        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            # self.gconv applies the first graph-convolution step;
            # self.gconv_net below applies the remaining graph-convolution layers
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # generate scene layout from bounding box and masks.
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        # concatenate layout_noise_dim channels of Gaussian noise to the layout
        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)
        # layout is a coarse feature map: each object's vector is placed into its box/mask region
        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores
Example #8
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                pred_to_img=None,
                boxes_gt=None,
                masks_gt=None,
                lstm_hidden=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    - lstm_hidden: Tensor of shape (N, self.lstm_hid_dim)
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(p)

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        # Bounding boxes should be conditioned on context
        # because layout is finalized at this step
        context = None
        if self.context_network is not None:
            context, embedding = self.context_network(pred_vecs, pred_to_img)
            # Concatenate global context to each object depending on which image it is from
            # Probably not an efficient way to do this
            obj_with_context = torch.stack([
                torch.cat((obj_vecs[i], embedding[obj_to_img[i].item()]))
                for i in range(O)
            ])
            boxes_pred = self.box_net(obj_with_context)

            masks_pred = None
            if self.mask_net is not None:
                mask_scores = self.mask_net(obj_with_context.view(O, -1, 1, 1))
                masks_pred = mask_scores.squeeze(1).sigmoid()

        else:
            boxes_pred = self.box_net(obj_vecs)

            masks_pred = None
            if self.mask_net is not None:
                mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
                masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        if lstm_hidden is not None:
            #print(lstm_hidden.size()[1],self.lstm_hid_dim)
            assert lstm_hidden.size()[0] == layout.size()[0]
            assert lstm_hidden.size()[1] == self.lstm_hid_dim
            lstm_embedding_vec = self.lstm_embedding(lstm_hidden)
            layout = torch.cat([layout, lstm_embedding_vec], dim=1)
        elif self.layout_noise_dim > 0:  # if not using lstm embedding
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)
        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores, context
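
The per-object context concatenation above (flagged in the comments as possibly inefficient) can also be written without the Python loop by indexing the per-image embedding with obj_to_img; a standalone sketch with made-up sizes:

import torch

O, D, E = 5, 128, 64
obj_vecs = torch.randn(O, D)                    # per-object vectors
embedding = torch.randn(2, E)                   # one global context vector per image
obj_to_img = torch.tensor([0, 0, 0, 1, 1])

obj_with_context = torch.cat([obj_vecs, embedding[obj_to_img]], dim=1)
assert obj_with_context.shape == (O, D + E)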