Example #1
    def forward(self, objs, layout_boxes, layout_masks, test_mode=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]
            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(objs_vecs_batch,
                                      layout_boxes_batch,
                                      layout_masks_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0],
                                      test_mode=test_mode)
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)
        seg = torch.cat(seg_batches, dim=0)

        # we downsample segmap and run convolution
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)
        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.opt.num_upsampling_layers == 'more' or \
                self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)  # torch.tanh replaces the deprecated F.tanh

        return x
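
The helpers remove_dummy_objects, boxes_to_layout and masks_to_layout are assumed to come from the surrounding project and are not shown here. As a rough illustration of the box-based compositor, the sketch below paints each object's embedding into its normalized box region of an H x W grid; the real helper is differentiable with respect to the boxes (e.g. via bilinear sampling), which this nearest-pixel version is not.

import torch


def boxes_to_layout_sketch(obj_vecs, boxes, H, W):
    """Illustrative stand-in for boxes_to_layout (single image, no obj_to_img).

    obj_vecs: FloatTensor of shape (O, D) with one embedding per object.
    boxes:    FloatTensor of shape (O, 4) with (x0, y0, x1, y1) in [0, 1].
    Returns a (1, D, H, W) layout; overlapping boxes are summed.
    """
    O, D = obj_vecs.shape
    layout = torch.zeros(1, D, H, W, dtype=obj_vecs.dtype, device=obj_vecs.device)
    for i in range(O):
        x0, y0, x1, y1 = boxes[i]
        ix0, iy0 = int(x0 * W), int(y0 * H)
        ix1 = max(int(x1 * W), ix0 + 1)  # keep at least one pixel per box
        iy1 = max(int(y1 * H), iy0 + 1)
        layout[0, :, iy0:iy1, ix0:ix1] += obj_vecs[i].view(D, 1, 1)
    return layout
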
Example #2
    def forward(self,
                img,
                objs,
                layout_boxes,
                layout_masks=None,
                gt_train=True,
                fool=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']

        # Masks Layout
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]

            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(
                    objs_vecs_batch,
                    layout_boxes_batch,
                    layout_masks_batch,
                    self.opt.image_size[0],
                    self.opt.image_size[0],
                    test_mode=False)  # test mode always false in disc.
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)

        # layout = torch.cat(layout_batches, dim=0)  # [B, N, d']
        seg = torch.cat(seg_batches, dim=0)
        disc_input = torch.cat([img, seg], dim=1)

        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for name, D in self.named_children():
            if name.startswith('discriminator'):
                out = D(disc_input)
                if not get_intermediate_features:
                    out = [out]
                result.append(out)
                disc_input = self.downsample(disc_input)
        return result
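
The nested `result` structure above (one list of feature maps per discriminator, with the final logits last) is the usual multi-scale discriminator output; a common way to consume it is a GAN feature-matching loss. The sketch below is a minimal, hypothetical example of such a loss and is not taken from this codebase.

import torch.nn.functional as F


def feature_matching_loss_sketch(pred_fake, pred_real):
    """Match intermediate discriminator features of fake and real inputs.

    pred_fake / pred_real: list (one per discriminator) of lists of feature
    maps, with the final logits as the last element of each inner list.
    """
    loss = 0.0
    num_d = len(pred_fake)
    for d in range(num_d):
        # skip the last element (the logits) and match the features only
        for feat_fake, feat_real in zip(pred_fake[d][:-1], pred_real[d][:-1]):
            loss = loss + F.l1_loss(feat_fake, feat_real.detach()) / num_d
    return loss
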
Example #3
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(
            objs)  # 'objs' => indices for model.vocab['object_idx_to_name']
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(
            p)  #  'p' => indices for model.vocab['pred_idx_to_name']

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        # bounding box prediction
        boxes_pred_info = None
        if self.use_bbox_info:
            # bounding box prediction + predicted box info
            boxes_pred_info = self.box_net(obj_vecs)
            # first 4 entries are bbox coords
            boxes_pred = boxes_pred_info[:, 0:4]
        else:
            boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        layout_masks = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # this only affects training if loss is non-zero
        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        # uses predicted subject/object boxes, original subject/object embedding (input to GCNN)
        ## use untrained embedding vectors
        ##rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs_pred, o_vecs_pred],
                                  dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        # concatenate triplet vectors
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

        # triplet bounding boxes
        triplet_boxes_pred = None
        if self.triplet_box_net is not None:
            # predict 8 point bounding boxes
            triplet_boxes_pred = self.triplet_box_net(triplet_input)

        # triplet binary masks
        triplet_masks_pred = None
        if self.triplet_mask_net is not None:
            # input dimension must be [h, w, 1, 1]
            triplet_mask_scores = self.triplet_mask_net(
                triplet_input[:, :, None, None])
            # only used for binary/masks CE loss
            #triplet_masks_pred = triplet_mask_scores.squeeze(1).sigmoid()
            triplet_masks_pred = triplet_mask_scores.squeeze(1)

        # triplet embedding
        triplet_embed = None
        if self.triplet_embed_net is not None:
            triplet_embed = self.triplet_embed_net(triplet_input)

        # triplet superbox
        triplet_superboxes_pred = None
        if self.triplet_superbox_net is not None:
            # predict 8 point bounding boxes
            triplet_superboxes_pred = self.triplet_superbox_net(
                triplet_input)  # s/p/o (bboxes?)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # compose layout mask
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        layout_crn = layout
        sg_context_pred = None
        sg_context_pred_d = None
        if self.sg_context_net is not None:
            N, C, H, W = layout.size()
            context = sg_context_to_layout(obj_vecs,
                                           obj_to_img,
                                           pooling=self.gcnn_pooling)
            sg_context_pred_sqz = self.sg_context_net(context)

            #### vector to spatial replication
            # use a separate name so the subject indices `s` are not clobbered
            b = N
            sdim = self.sg_context_dim
            # b, sdim = sg_context_pred_sqz.size()
            sg_context_pred = sg_context_pred_sqz.view(b, sdim, 1, 1).expand(
                b, sdim, layout.size(2), layout.size(3))
            layout_crn = torch.cat([layout, sg_context_pred], dim=1)

            ## discriminator uses different FC layer than the generator
            sg_context_predd_sqz = self.sg_context_net_d(context)
            sdim = self.sg_context_dim_d
            sg_context_pred_d = sg_context_predd_sqz.view(b, sdim, 1, 1).expand(
                b, sdim, layout.size(2), layout.size(3))

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

        # layout model only
        #img = self.refinement_net(layout_crn)
        img = None

        # compose triplet boxes using 'triplets', objs, etc.
        if boxes_gt is not None:
            s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
            triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
        else:
            triplet_boxes_gt = None

        #return img, boxes_pred, masks_pred, rel_scores
        return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
                layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
                rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
                triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
                triplet_superboxes_pred)
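
For readers unfamiliar with the triple unpacking at the top of this forward pass, the toy example below (with made-up indices) shows the shapes produced by chunk, squeeze and stack.

import torch

# Toy illustration of the triple unpacking used above; values are made up.
triples = torch.tensor([[0, 2, 1],
                        [1, 5, 3],
                        [3, 2, 0]])          # (T, 3): T = 3 triples [s, p, o]
s, p, o = triples.chunk(3, dim=1)            # each has shape (T, 1)
s, p, o = [x.squeeze(1) for x in (s, p, o)]  # each has shape (T,)
edges = torch.stack([s, o], dim=1)           # (T, 2): subject -> object edges
print(edges.tolist())                        # [[0, 1], [1, 3], [3, 0]]
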
Example #4
    def forward(self,
                obj_to_img,
                boxes_gt,
                obj_fmaps,
                mask_noise_indexes=None,
                masks_gt=None,
                bg_layout=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        assert boxes_gt.max() < 1.1 and boxes_gt.min() > -0.1, \
            "boxes_gt should be within range [0,1]"
        # O, T = objs.size(0), triples.size(0)
        # s, p, o = triples.chunk(3, dim=1)           # All have shape (T, 1)
        # s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,)
        # edges = torch.stack([s, o], dim=1)          # Shape is (T, 2)
        #
        # if obj_to_img is None:
        #   obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)
        #
        # obj_vecs = self.obj_embeddings(objs)
        # obj_vecs_orig = obj_vecs
        # pred_vecs = self.pred_embeddings(p)
        #
        # if isinstance(self.gconv, nn.Linear):
        #   obj_vecs = self.gconv(obj_vecs)
        # else:
        #   obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        # if self.gconv_net is not None:
        #   obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)
        if self.args.not_decrease_feature_dimension:
            obj_vecs = obj_fmaps
        else:
            obj_vecs = self.obj_fmap_net(obj_fmaps)
        no_noise_obj_vecs = obj_vecs
        if self.args.object_noise_dim > 0:
            # select objs belongs to images in mask_noise_indexes
            if mask_noise_indexes is not None and self.training:
                mask_noise_obj_index_list = []
                for ind in mask_noise_indexes:
                    mask_noise_obj_index_list.append(
                        (obj_to_img == ind).nonzero())
                mask_noise_obj_indexes = torch.cat(mask_noise_obj_index_list,
                                                   dim=0)[:, 0]

            if self.args.noise_apply_method == "concat":
                object_noise = torch.randn(
                    (obj_vecs.shape[0], self.args.object_noise_dim),
                    dtype=obj_vecs.dtype,
                    device=obj_vecs.device)
                if mask_noise_indexes is not None and self.training:
                    object_noise[mask_noise_obj_indexes] = 0
                obj_vecs = torch.cat([obj_vecs, object_noise], dim=1)
            elif self.args.noise_apply_method == "add":
                object_noise = torch.randn(obj_vecs.shape,
                                           dtype=obj_vecs.dtype,
                                           device=obj_vecs.device)
                if mask_noise_indexes is not None and self.training:
                    object_noise[mask_noise_obj_indexes] = 0
                obj_vecs = obj_vecs + object_noise

        # boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(
                obj_vecs.view(obj_vecs.shape[0], -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        # s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        # rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        # rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        # layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
        layout_boxes = boxes_gt

        # layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)

        if masks_pred is None:
            if self.args.object_no_noise_with_bbox:
                layout = boxes_to_layout(no_noise_obj_vecs, layout_boxes,
                                         obj_to_img, H, W)
            else:
                layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H,
                                         W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            if self.args.object_no_noise_with_mask:
                layout = masks_to_layout(no_noise_obj_vecs, layout_boxes,
                                         layout_masks, obj_to_img, H, W)
            else:
                layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                         obj_to_img, H, W)
        ret_layout = layout

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            if self.args.noise_apply_method == "concat":
                noise_shape = (N, self.layout_noise_dim, H, W)
            elif self.args.noise_apply_method == "add":
                noise_shape = layout.shape
            # print("check noise_std here, it is %.10f" % self.args.noise_std)
            noise_std = torch.zeros(noise_shape,
                                    dtype=layout.dtype,
                                    device=layout.device).fill_(
                                        self.args.noise_std)
            layout_noise = torch.normal(mean=0.0, std=noise_std)
            if self.args.layout_noise_only_on_foreground:
                layout_noise *= (1 - bg_layout[:, :1, :, :].repeat(
                    1, self.layout_noise_dim, 1, 1))

            if mask_noise_indexes is not None and self.training:
                layout_noise[mask_noise_indexes] = 0.
            # layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
            #                            device=layout.device)
            if self.args.noise_apply_method == "concat":
                layout = torch.cat([layout, layout_noise], dim=1)
            elif self.args.noise_apply_method == "add":
                layout = layout + layout_noise
        img = self.refinement_net(layout)
        return img, ret_layout
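
The noise handling above supports two modes selected by args.noise_apply_method. The sketch below isolates just that logic under a hypothetical helper name so the two modes are easier to compare.

import torch


def apply_layout_noise_sketch(layout, noise_dim, noise_std, method="concat"):
    """Minimal sketch of the 'concat' vs. 'add' noise modes used above."""
    N, C, H, W = layout.shape
    if method == "concat":
        # extra noise channels are appended to the layout
        noise = torch.randn(N, noise_dim, H, W,
                            dtype=layout.dtype, device=layout.device) * noise_std
        return torch.cat([layout, noise], dim=1)
    if method == "add":
        # noise of the same shape is added element-wise
        noise = torch.randn_like(layout) * noise_std
        return layout + noise
    raise ValueError("unknown noise_apply_method: %s" % method)
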
Example #5
    def forward(self, objs, triples, obj_to_img=None, pred_to_img=None,
              boxes_gt=None, masks_gt=None):
        """
        Required Inputs:
        - objs: LongTensor of shape (O,) giving categories for all objects
        - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
          means that there is a triple (objs[s], p, objs[o])

        Optional Inputs:
        - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
          means that objects[o] is an object in image i. If not given then
          all objects are assumed to belong to the same image.
        - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
          the spatial layout; if not given then use predicted boxes.
        """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)           # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]] # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)          # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs, pred_vecs = self.embedding(objs, p)
        obj_vecs_orig = obj_vecs

        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        boxes_pred = self.box_net(obj_vecs)
        
        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                   obj_to_img, H, W)

        # Add context embedding
#         context = self.context_network(pred_vecs)
        # TODO how to concatenate this?

#         if self.layout_noise_dim > 0:
#             N, C, H, W = layout.size()
#             # Concatenate noise with new context embedding and make proper shape
#             noise = torch.randn(N, self.layout_noise_dim)
#             noise = noise.view(noise.size(0), self.layout_noise_dim)
#             z = torch.cat([noise,proj_c],1)
#             layout_noise = self.noise_layout(z)
#             layout = torch.cat([layout, layout_noise], dim=1)
        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                     device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)

        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores
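
masks_to_layout is again assumed from the surrounding project. The sketch below shows the idea for a single image (the obj_to_img batching argument is omitted): each soft mask is resized to its box and modulates the object's embedding before it is pasted into the layout. The real helper is differentiable with respect to the boxes; this pixel-rounding version is illustrative only.

import torch
import torch.nn.functional as F


def masks_to_layout_sketch(obj_vecs, boxes, masks, H, W):
    """Illustrative stand-in for masks_to_layout (single image).

    obj_vecs: (O, D) embeddings, boxes: (O, 4) in [0, 1], masks: (O, M, M).
    """
    O, D = obj_vecs.shape
    layout = torch.zeros(1, D, H, W, dtype=obj_vecs.dtype, device=obj_vecs.device)
    for i in range(O):
        x0, y0, x1, y1 = boxes[i]
        ix0, iy0 = int(x0 * W), int(y0 * H)
        ix1 = max(int(x1 * W), ix0 + 1)
        iy1 = max(int(y1 * H), iy0 + 1)
        # resize the soft mask to the box and weight the embedding with it
        m = F.interpolate(masks[i][None, None], size=(iy1 - iy0, ix1 - ix0),
                          mode='bilinear', align_corners=False)[0, 0]
        layout[0, :, iy0:iy1, ix0:ix1] += obj_vecs[i].view(D, 1, 1) * m
    return layout
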
Example #6
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # permute = ind[torch.randperm(len(ind))]
                    # obj_fmaps[ind] = obj_fmaps[permute]
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        # obj_to_img, boxes, obj_fmaps, mask_noise_indexes
        half_size = imgs.shape[0] // 2

        obj_index_encoded = []
        obj_index_random = []
        for ind in range(half_size):
            obj_index_encoded.append((obj_to_img == ind).nonzero()[:, 0])
        obj_index_encoded = torch.cat(obj_index_encoded)
        for ind in range(half_size, imgs.shape[0]):
            obj_index_random.append((obj_to_img == ind).nonzero()[:, 0])
        obj_index_random = torch.cat(obj_index_random)

        imgs_encoded = imgs[:half_size]
        obj_to_img_encoded = obj_to_img[obj_index_encoded]
        boxes_encoded = boxes[obj_index_encoded]
        obj_fmaps_encoded = obj_fmaps[obj_index_encoded]
        mask_noise_indexes_encoded = torch.randperm(
            half_size)[:int(self.args.noise_mask_ratio * half_size)].to(
                imgs.device)
        if len(mask_noise_indexes_encoded) == 0:
            mask_noise_indexes_encoded = None
        crops_encoded = crop_bbox_batch(imgs_encoded, boxes_encoded,
                                        obj_to_img_encoded,
                                        self.args.crop_size)

        imgs_random = imgs[half_size:]
        obj_to_img_random = obj_to_img[obj_index_random] - half_size
        boxes_random = boxes[obj_index_random]
        obj_fmaps_random = obj_fmaps[obj_index_random]
        mask_noise_indexes_random = torch.randperm(imgs.shape[0] - half_size)\
            [:int(self.args.noise_mask_ratio * (imgs.shape[0] - half_size))].to(imgs.device)
        if len(mask_noise_indexes_random) == 0:
            mask_noise_indexes_random = None
        # crops_random = crop_bbox_batch(imgs_random, boxes_random, obj_to_img_random, self.args.crop_size)

        mask_noise_indexes = None
        if mask_noise_indexes_encoded is not None:
            mask_noise_indexes = mask_noise_indexes_encoded
        if mask_noise_indexes_random is not None:
            if mask_noise_indexes is not None:
                mask_noise_indexes = torch.cat([
                    mask_noise_indexes, mask_noise_indexes_random + half_size
                ])
            else:
                mask_noise_indexes = mask_noise_indexes_random + half_size

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                if self.training:
                    mu_encoded, logvar_encoded = self.obj_encoder(
                        crops_encoded)
                    std = logvar_encoded.mul(0.5).exp_()
                    eps = torch.randn((std.size(0), std.size(1)),
                                      dtype=std.dtype,
                                      device=std.device)
                    z_encoded = eps.mul(std).add_(mu_encoded)
                    z_random = torch.randn((obj_fmaps_random.shape[0],
                                            self.args.object_noise_dim),
                                           dtype=obj_fmaps_random.dtype,
                                           device=obj_fmaps_random.device)

                    imgs_pred_encoded, layout_encoded = self.model(
                        obj_to_img_encoded,
                        boxes_encoded,
                        obj_fmaps_encoded,
                        mask_noise_indexes=mask_noise_indexes_encoded,
                        object_noise=z_encoded)
                    imgs_pred_random, layout_random = self.model(
                        obj_to_img_random,
                        boxes_random,
                        obj_fmaps_random,
                        mask_noise_indexes=mask_noise_indexes_random,
                        object_noise=z_random)

                    crops_pred_encoded = crop_bbox_batch(
                        imgs_pred_encoded, boxes_encoded, obj_to_img_encoded,
                        self.args.crop_size)

                    crops_pred_random = crop_bbox_batch(
                        imgs_pred_random, boxes_random, obj_to_img_random,
                        self.args.crop_size)
                    mu_rec, logvar_rec = self.obj_encoder(crops_pred_random)
                    z_random_rec = mu_rec

                    imgs_pred = torch.cat(
                        [imgs_pred_encoded, imgs_pred_random], dim=0)

                    layout = torch.cat([layout_encoded, layout_random],
                                       dim=0).detach()
                else:
                    z_random = torch.randn(
                        (obj_fmaps.shape[0], self.args.object_noise_dim),
                        dtype=obj_fmaps.dtype,
                        device=obj_fmaps.device)
                    imgs_pred, layout = self.model(
                        obj_to_img,
                        boxes,
                        obj_fmaps,
                        mask_noise_indexes=mask_noise_indexes,
                        object_noise=z_random)
                    layout = layout.detach()
                    crops_encoded = None
                    crops_pred_encoded = None
                    z_random = None
                    z_random_rec = None
                    mu_encoded = None
                    logvar_encoded = None

        H, W = self.args.image_size
        bg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (bg_layout <= 0).type(imgs.dtype)

        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      obj_fmaps=obj_fmaps,
                      boxes=boxes,
                      obj_to_img=obj_to_img + img_offset,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img,
                      d_obj_gp=d_obj_gp,
                      d_img_gp=d_img_gp,
                      fake_crops=fake_crops,
                      real_crops=real_crops,
                      mask_noise_indexes=(mask_noise_indexes + img_offset)
                      if mask_noise_indexes is not None else None,
                      g_rec_feature_fake_crop=g_rec_feature_fake_crop,
                      d_rec_feature_fake_crop=d_rec_feature_fake_crop,
                      d_rec_feature_real_crop=d_rec_feature_real_crop,
                      g_scores_fake_bg=g_scores_fake_bg,
                      d_scores_fake_bg=d_scores_fake_bg,
                      d_scores_real_bg=d_scores_real_bg,
                      d_bg_gp=d_bg_gp,
                      bg_layout=bg_layout,
                      crops_encoded=crops_encoded,
                      crops_pred_encoded=crops_pred_encoded,
                      z_random=z_random,
                      z_random_rec=z_random_rec,
                      mu_encoded=mu_encoded,
                      logvar_encoded=logvar_encoded)
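
gradient_penalty is used above for the "wgan-gp" loss but is defined elsewhere. The sketch below shows a standard WGAN-GP penalty; the project's version may differ in how it handles interpolation shapes or weights the penalty.

import torch


def gradient_penalty_sketch(real, fake, discriminator, weight=10.0):
    """Standard WGAN-GP gradient penalty on samples interpolated between
    real and fake inputs (assumes 4-D image-like tensors)."""
    N = real.size(0)
    alpha = torch.rand(N, 1, 1, 1, dtype=real.dtype, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores.sum(), inputs=interp,
                                create_graph=True)[0]
    grad_norm = grads.view(N, -1).norm(2, dim=1)
    return weight * ((grad_norm - 1) ** 2).mean()
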
Example #7
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None,
                tr_to_img=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(
            objs)  # 'objs' => indices for model.vocab['object_idx_to_name']
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(
            p)  #  'p' => indices for model.vocab['pred_idx_to_name']
        pred_vecs_orig = pred_vecs

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        #### object context vectors ###############
        # the last entry gives the image count only if obj_to_img is sorted by image
        num_imgs = obj_to_img[obj_to_img.size(0) - 1] + 1
        context_obj_vecs = torch.zeros(num_imgs,
                                       obj_vecs.size(1),
                                       dtype=obj_vecs.dtype,
                                       device=obj_vecs.device)
        obj_to_img_exp = obj_to_img.view(-1, 1).expand_as(obj_vecs)
        context_obj_vecs = context_obj_vecs.scatter_add(
            0, obj_to_img_exp, obj_vecs)

        # get object counts
        obj_counts = torch.zeros(num_imgs,
                                 dtype=obj_vecs.dtype,
                                 device=obj_vecs.device)
        ones = torch.ones(obj_to_img.size(0),
                          dtype=obj_vecs.dtype,
                          device=obj_vecs.device)
        obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
        context_obj_vecs = context_obj_vecs / obj_counts.view(-1, 1)
        context_obj_vecs = context_obj_vecs[obj_to_img]
        context_obj_vecs = context_obj_vecs[s]
        ####################################

        ####### triplet context vectors ###########
        #context_tr_vecs = None
        # concatenate triplet vectors
        #triplets = torch.cat([obj_vecs[s], pred_vecs, obj_vecs[o]], dim=1)
        #context_tr_vecs = torch.zeros(num_imgs, triplets.size(1), dtype=obj_vecs.dtype, device=obj_vecs.device)
        # need triplet to image
        #tr_to_img_exp = tr_to_img.view(-1, 1).expand_as(triplets)
        #context_tr_vecs = context_tr_vecs.scatter_add(0, tr_to_img_exp, triplets)
        # get triplet counts
        #tr_counts = torch.zeros(num_imgs, dtype=obj_vecs.dtype, device=obj_vecs.device)
        #ones = torch.ones(triplets.size(0), dtype=obj_vecs.dtype, device=obj_vecs.device)
        #tr_counts = tr_counts.scatter_add(0, tr_to_img, ones)
        #context_tr_vecs = context_tr_vecs/tr_counts.view(-1,1)
        # dimension is (# triplets, 3*input_dim)
        #context_tr_vecs = context_tr_vecs[tr_to_img]
        # get some context!
        #context_tr_vecs = self.triplet_context_net(context_tr_vecs)
        ###########################################

        ####  mask out some predicates #####
        pred_mask_gt = None
        pred_mask_scores = None
        if self.use_masked_sg:
            perc = torch.FloatTensor([0.50])  # hyperparameter
            num_mask_objs = torch.floor(perc *
                                        len(s)).cpu().numpy()[0].astype(int)
            if num_mask_objs < 1:
                num_mask_objs = 1
            mask_idx = torch.randint(0, len(s) - 1, (num_mask_objs, ))
            #rand_idx = torch.randperm(len(s)-1)
            #mask_idx = rand_idx[:num_mask_objs]
            # GT
            pred_mask_gt = p[mask_idx.long()]  # return
            # set mask idx to masked embedding (e.g. new SG!)
            pred_vecs_copy = pred_vecs_orig
            ##### need to add i=46 None embedding
            pred_vecs_copy[mask_idx.long()] = self.pred_embeddings(
                torch.tensor([self.mask_pred]).cuda())

            # convolve new masked SG
            if isinstance(self.gconv, nn.Linear):
                mask_obj_vecs = self.gconv(obj_vecs_orig)
                # keep the masked predicate embeddings so mask_pred_vecs is
                # defined in this branch as well
                mask_pred_vecs = pred_vecs_copy
            else:
                mask_obj_vecs, mask_pred_vecs = self.gconv(
                    obj_vecs_orig, pred_vecs_copy, edges)
            if self.gconv_net is not None:
                mask_obj_vecs, mask_pred_vecs = self.gconv_net(
                    mask_obj_vecs, mask_pred_vecs, edges)

            # subj/obj obj idx
            s_mask = s[mask_idx.long()]
            o_mask = o[mask_idx.long()]

            subj_vecs_mask = mask_obj_vecs[s_mask]
            obj_vecs_mask = mask_obj_vecs[o_mask]

            # predict masked predicate relationship
            pred_mask_input = torch.cat([subj_vecs_mask, obj_vecs_mask], dim=1)
            pred_mask_scores = self.pred_mask_net(pred_mask_input)
        #####################

        # bounding box prediction
        boxes_pred_info = None
        if self.use_bbox_info:
            # bounding box prediction + predicted box info
            boxes_pred_info = self.box_net(obj_vecs)
            # first 4 entries are bbox coords
            boxes_pred = boxes_pred_info[:, 0:4]
        else:
            boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        layout_masks = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        # predicted bboxes and embedding vectors
        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        # input embedding vectors
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        p_vecs = pred_vecs_orig
        input_tr_vecs = torch.cat([s_vecs, p_vecs, o_vecs], dim=1)

        # VSA (with obj/pred vectors of varying kinds)
        fr_obj_vecs = self.fr_obj_embeddings(objs)
        fr_pred_vecs = self.fr_pred_embeddings(p)
        fr_s_vecs, fr_o_vecs = fr_obj_vecs[s], fr_obj_vecs[o]
        mapc_bind = fr_s_vecs * fr_o_vecs * fr_pred_vecs
        # mapc_bind = s_vecs * o_vecs * p_vecs
        #mapc_bind = s_vecs_pred * o_vecs_pred * pred_vecs
        mapc_bind = F.normalize(mapc_bind, p=2, dim=1)

        # uses predicted subject/object boxes, original subject/object embedding (input to GCNN)
        ## use original embedding vectors
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        # subject prediction
        subj_aux_input = torch.cat([s_boxes, o_boxes, p_vecs, o_vecs], dim=1)
        subj_scores = self.subj_aux_net(subj_aux_input)

        # object prediction
        obj_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, p_vecs], dim=1)
        obj_scores = self.obj_aux_net(obj_aux_input)

        # object class prediction (for output object vectors)
        obj_class_scores = self.obj_class_aux_net(obj_vecs)

        # relationship class prediction (for output object vectors)
        # relationship embedding (very small embedding)
        # augment relationship embedding
        use_augmentation = False
        mask_rel_embedding = None
        if use_augmentation:
            num_augs = 4
            num_preds = len(p)
            mask_rel_embedding = torch.zeros(
                (num_preds, num_augs + 1, self.embedding_dim),
                dtype=obj_vecs.dtype,
                device=obj_vecs.device)
            rel_embedding = []
            #pred_vecs_mask = np.zeros((num_preds,num_augs,self.embedding_dim))
            # mask embedding perc_aug of vectors with 0
            perc_aug = torch.FloatTensor([0.4])  # hyperparameter
            num_mask = torch.floor(
                perc_aug * len(pred_vecs[0])).cpu().numpy()[0].astype(int)
            #p_ids = []
            for i in range(num_preds):
                pred = pred_vecs[i]
                #p_id = p[i]
                vecs = [pred]
                #p_ids += [p_id]
                for j in range(num_augs):
                    # pick a random set of indices to zero-out
                    rand_idx = torch.randperm(len(
                        pred_vecs[0]))  # 0-127 range shuffled
                    mask_idx = rand_idx[:num_mask]
                    pred_mask = pred.detach().clone()
                    pred_mask[mask_idx] = 0.0
                    vecs += [pred_mask]
                # project masked augmented relationship vectors
                pred_mask = self.rel_embed_aux_net(
                    torch.stack(vecs))  # output predicate embeddings
                pred_mask = F.normalize(pred_mask, dim=1)
                mask_rel_embedding[i, :, :] = pred_mask
                rel_embedding += [pred_mask[0]]

        #rel_embedding = torch.stack(rel_embedding)
        # projection head for supervised contrastive loss
        rel_embedding = self.rel_embed_aux_net(
            pred_vecs)  # output projected predicate embeddings
        rel_embedding = F.normalize(rel_embedding, dim=1)

        # relationship class prediction on predicates
        rel_class_scores = self.rel_class_aux_net(pred_vecs)

        # concatenate triplet vectors
        s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
        triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

        # triplet bounding boxes
        triplet_boxes_pred = None
        if self.triplet_box_net is not None:
            # predict 8 point bounding boxes
            triplet_boxes_pred = self.triplet_box_net(triplet_input)

        # triplet binary masks
        triplet_masks_pred = None
        if self.triplet_mask_net is not None:
            # input dimension must be [h, w, 1, 1]
            triplet_mask_scores = self.triplet_mask_net(
                triplet_input[:, :, None, None])
            # only used for binary/masks CE loss
            #triplet_masks_pred = triplet_mask_scores.squeeze(1).sigmoid()
            triplet_masks_pred = triplet_mask_scores.squeeze(1)

        # triplet embedding
        triplet_embed = None
        if self.triplet_embed_net is not None:
            triplet_embed = self.triplet_embed_net(triplet_input)

        # triplet superbox
        triplet_superboxes_pred = None
        if self.triplet_superbox_net is not None:
            # predict 2 point superboxes
            triplet_superboxes_pred = self.triplet_superbox_net(
                triplet_input)  # s/p/o (bboxes?)

        # predicate grounding
        pred_ground = None
        if self.pred_ground_net is not None:
            # predict 2 point pred grounding
            pred_ground = self.pred_ground_net(pred_vecs)  # s/p/o (bboxes?)

        # triplet context
        triplet_context_input = torch.cat(
            [context_obj_vecs, s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)
        # output dimension is 384
        context_tr_vecs = self.triplet_context_net(triplet_context_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # compose layout mask
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)
        layout_crn = layout
        sg_context_pred = None
        sg_context_pred_d = None
        if self.sg_context_net is not None:
            N, C, H, W = layout.size()
            context = sg_context_to_layout(obj_vecs,
                                           obj_to_img,
                                           pooling=self.gcnn_pooling)
            sg_context_pred_sqz = self.sg_context_net(context)

            #### vector to spatial replication
            # use a separate name so the subject indices `s` are not clobbered
            b = N
            sdim = self.sg_context_dim
            # b, sdim = sg_context_pred_sqz.size()
            sg_context_pred = sg_context_pred_sqz.view(b, sdim, 1, 1).expand(
                b, sdim, layout.size(2), layout.size(3))
            layout_crn = torch.cat([layout, sg_context_pred], dim=1)

            ## discriminator uses different FC layer than the generator
            sg_context_predd_sqz = self.sg_context_net_d(context)
            sdim = self.sg_context_dim_d
            sg_context_pred_d = sg_context_predd_sqz.view(b, sdim, 1, 1).expand(
                b, sdim, layout.size(2), layout.size(3))

        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

        # layout model only
        #img = self.refinement_net(layout_crn)
        img = None

        # compose triplet boxes using 'triplets', objs, etc.
        if boxes_gt is not None:
            s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
            triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
        else:
            triplet_boxes_gt = None

        #return img, boxes_pred, masks_pred, rel_scores
        return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
                layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
                rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
                triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
                triplet_superboxes_pred, obj_scores, pred_mask_gt,
                pred_mask_scores, context_tr_vecs, input_tr_vecs,
                obj_class_scores, rel_class_scores, subj_scores,
                rel_embedding, mask_rel_embedding, pred_ground)  #, mapc_bind
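
The "object context vectors" block above computes a per-image mean of the object vectors with scatter_add. The standalone sketch below reproduces that pattern under a hypothetical helper name, which may help when reading the scatter indexing.

import torch


def per_image_mean_pool_sketch(obj_vecs, obj_to_img, num_imgs):
    """Per-image mean of object vectors via scatter_add, as in the context block."""
    D = obj_vecs.size(1)
    sums = torch.zeros(num_imgs, D, dtype=obj_vecs.dtype, device=obj_vecs.device)
    sums = sums.scatter_add(0, obj_to_img.view(-1, 1).expand(-1, D), obj_vecs)
    counts = torch.zeros(num_imgs, dtype=obj_vecs.dtype, device=obj_vecs.device)
    counts = counts.scatter_add(0, obj_to_img,
                                torch.ones_like(obj_to_img, dtype=obj_vecs.dtype))
    return sums / counts.clamp(min=1).view(-1, 1)
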
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # permute = ind[torch.randperm(len(ind))]
                    # obj_fmaps[ind] = obj_fmaps[permute]
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        mask_noise_indexes = torch.randperm(
            imgs.shape[0])[:int(self.args.noise_mask_ratio *
                                imgs.shape[0])].to(imgs.device)
        if len(mask_noise_indexes) == 0:
            mask_noise_indexes = None

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                imgs_pred, layout, z_random = self.model(
                    obj_to_img, boxes, obj_fmaps, mask_noise_indexes)

                if self.training:
                    mu_rec, logvar_rec = self.img_encoder(imgs_pred)
                    z_random_rec = mu_rec
                else:
                    z_random_rec = None

        H, W = self.args.image_size
        bg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (bg_layout <= 0).type(imgs.dtype)

        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      obj_fmaps=obj_fmaps,
                      boxes=boxes,
                      obj_to_img=obj_to_img + img_offset,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img,
                      d_obj_gp=d_obj_gp,
                      d_img_gp=d_img_gp,
                      fake_crops=fake_crops,
                      real_crops=real_crops,
                      mask_noise_indexes=(mask_noise_indexes + img_offset)
                      if mask_noise_indexes is not None else None,
                      g_rec_feature_fake_crop=g_rec_feature_fake_crop,
                      d_rec_feature_fake_crop=d_rec_feature_fake_crop,
                      d_rec_feature_real_crop=d_rec_feature_real_crop,
                      g_scores_fake_bg=g_scores_fake_bg,
                      d_scores_fake_bg=d_scores_fake_bg,
                      d_scores_real_bg=d_scores_real_bg,
                      d_bg_gp=d_bg_gp,
                      bg_layout=bg_layout,
                      z_random=z_random,
                      z_random_rec=z_random_rec)
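
The WGAN-GP branches above call a gradient_penalty(real, fake, discriminator) helper whose body is not part of this listing. Below is a minimal sketch of the standard gradient penalty, assuming 4-D image (or crop) batches and a discriminator that returns a tensor of per-sample scores; it mirrors the call signature used above but is an illustration, not the repository's actual implementation.

import torch


def gradient_penalty(real, fake, discriminator, gp_weight=1.0):
    # Penalize deviations of the discriminator's gradient norm from 1 at
    # random interpolations between real and fake samples (WGAN-GP).
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores.sum(),
                                inputs=interp,
                                create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()
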
Example #9
0
    def forward(self, obj_vecs, boxes_pred, triples, obj_to_img=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    """
        O = obj_vecs.size(0)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=torch.long, device=obj_vecs.device)

        # In this variant the graph convolution, mask prediction and
        # relationship scoring stages are bypassed: obj_vecs and boxes_pred
        # arrive precomputed from the caller.

        H, W = self.image_size
        layout_boxes = boxes_pred

        # Rasterize the object embeddings into an (N, D, H, W) layout
        # (this variant assumes CUDA tensors and omits the optional
        # layout-noise channels used elsewhere).
        layout = boxes_to_layout(obj_vecs.cuda(), layout_boxes.cuda(),
                                 obj_to_img.cuda(), H, W)

        img = self.refinement_net(layout)
        return img
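
Example #9 leans on boxes_to_layout to turn per-object embedding vectors into an (N, D, H, W) layout tensor; the helper itself is not shown in this listing. The sketch below is a simplified nearest-neighbour approximation of that idea (the real helper is typically implemented with differentiable bilinear sampling), and the _sketch suffix plus the assumed (x0, y0, x1, y1) box convention in [0, 1] are illustrative assumptions.

import torch


def boxes_to_layout_sketch(obj_vecs, boxes, obj_to_img, H, W=None):
    # Simplified sketch: paint each object's embedding into its box region
    # and sum overlapping objects per image.
    # obj_vecs: (O, D); boxes: (O, 4) assumed (x0, y0, x1, y1) in [0, 1];
    # obj_to_img: (O,) image index for every object.
    if W is None:
        W = H
    O, D = obj_vecs.shape
    N = int(obj_to_img.max().item()) + 1
    layout = torch.zeros(N, D, H, W, dtype=obj_vecs.dtype, device=obj_vecs.device)
    for o in range(O):
        x0, y0, x1, y1 = boxes[o].tolist()
        c0, c1 = int(x0 * W), max(int(x1 * W), int(x0 * W) + 1)
        r0, r1 = int(y0 * H), max(int(y1 * H), int(y0 * H) + 1)
        layout[obj_to_img[o], :, r0:r1, c0:c1] += obj_vecs[o].view(D, 1, 1)
    return layout
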
Example #10
0
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                boxes_gt=None,
                masks_gt=None):
        """ 
        Required Inputs:
        - objs: LongTensor of shape (O,) giving categories for all objects
        - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
          means that there is a triple (objs[s], p, objs[o])
        
        main Process: graph >>> graph conv >>> layout >>> CRN.

        Optional Inputs:
        - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
          means that objects[o] is an object in image i. If not given then
          all objects are assumed to belong to the same image.
        - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
          the spatial layout; if not given then use predicted boxes.
        """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        # edges specify start and end.
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(p)

        # Graph convolutional network.
        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            # self.gconv is the first graph-convolution layer;
            # self.gconv_net stacks any additional layers.
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        # generate scene layout from bounding box and masks.
        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        # Optionally append layout_noise_dim channels of Gaussian noise so the
        # generated image is not fully determined by the layout.
        if self.layout_noise_dim > 0:
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)
        # The layout is a rasterized feature map built from boxes or masks;
        # the refinement network decodes it into an image.
        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores
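
For reference, Example #10 expects objs as a flat LongTensor of category indices and triples as rows of (subject index, predicate index, object index) into that tensor. The toy snippet below is a hypothetical illustration of how such inputs could be assembled; the category/predicate indices and the model variable are made up, not taken from the repository.

import torch

# Hypothetical toy scene graph: three objects in one image, two relationships.
# The category and predicate indices are placeholders, not a real vocabulary.
objs = torch.tensor([5, 12, 3], dtype=torch.long)        # (O,)
triples = torch.tensor([[0, 1, 1],                       # objs[0] --pred 1--> objs[1]
                        [1, 4, 2]], dtype=torch.long)    # objs[1] --pred 4--> objs[2]
obj_to_img = torch.zeros(3, dtype=torch.long)            # all objects belong to image 0

# `model` is assumed to be an instance of the module whose forward is shown above:
# img, boxes_pred, masks_pred, rel_scores = model(objs, triples, obj_to_img)
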
Example #11
0
    def forward(self, imgs, img_offset, gt_boxes, gt_classes, gt_fmaps):
        # In this variant the detector is not run; ground-truth boxes, classes
        # and object feature maps are consumed directly.

        obj_to_img = gt_classes[:, 0] - img_offset
        # print("obj_to_img.min(), obj_to_img.max(), len(imgs) {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs)))
        assert obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs), \
            "obj_to_img.min() >= 0 and obj_to_img.max() < len(imgs) is not satidfied: {} {} {}".format(obj_to_img.min(), obj_to_img.max(), len(imgs))
        boxes = gt_boxes
        obj_fmaps = gt_fmaps
        objs = gt_classes[:, 1]

        if self.args is not None:
            if self.args.exchange_feat_cls:
                print("exchange feature vectors and classes among bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    # Swap the feature vectors and class labels of two
                    # randomly chosen objects in this image.
                    permute_ind = ind[torch.randperm(len(ind))[:2]]
                    permute = permute_ind[[1, 0]]
                    obj_fmaps[permute_ind] = obj_fmaps[permute]
                    objs[permute_ind] = objs[permute]

            if self.args.change_bbox:
                print("change the position of bboxes")
                for img_ind in range(imgs.shape[0]):
                    ind = (obj_to_img == img_ind).nonzero()[:, 0]
                    ind = ind[torch.randperm(len(ind))[0]]
                    if boxes[ind][3] < 0.8:
                        print("move to bottom")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                    elif boxes[ind][1] > 0.2:
                        print("move to top")
                        boxes[ind][3] -= boxes[ind][1]
                        boxes[ind][1] = 0
                    elif boxes[ind][0] > 0.2:
                        print("move to left")
                        boxes[ind][2] -= boxes[ind][0]
                        boxes[ind][0] = 0
                    elif boxes[ind][2] < 0.8:
                        print("move to right")
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1
                    else:
                        print("move to bottom right")
                        boxes[ind][1] += (1 - boxes[ind][3])
                        boxes[ind][3] = 1
                        boxes[ind][0] += (1 - boxes[ind][2])
                        boxes[ind][2] = 1

        mask_noise_indexes = torch.randperm(
            imgs.shape[0])[:int(self.args.noise_mask_ratio *
                                imgs.shape[0])].to(imgs.device)
        if len(mask_noise_indexes) == 0:
            mask_noise_indexes = None

        H, W = self.args.image_size
        fg_layout = boxes_to_layout(
            torch.ones(boxes.shape[0], 3).to(imgs.device), boxes, obj_to_img,
            H, W)
        bg_layout = (fg_layout <= 0).type(imgs.dtype)

        if self.forward_G:
            with timeit('generator forward', self.args.timing):
                imgs_pred, layout = self.model(obj_to_img,
                                               boxes,
                                               obj_fmaps,
                                               mask_noise_indexes,
                                               bg_layout=bg_layout)

        layout = layout.detach()
        if self.args.condition_d_img_on_class_label_map:
            layout = boxes_to_layout(
                (objs + 1).view(-1, 1).repeat(1, 3).type(imgs.dtype), boxes,
                obj_to_img, H, W)

        g_scores_fake_crop, g_obj_scores_fake_crop, g_rec_feature_fake_crop = None, None, None
        g_scores_fake_img = None
        g_scores_fake_bg = None
        if self.calc_G_D_loss:
            # forward discriminators to train generator
            if self.obj_discriminator is not None:
                with timeit('d_obj forward for g', self.args.timing):
                    g_scores_fake_crop, g_obj_scores_fake_crop, _, g_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_pred, objs, boxes, obj_to_img)

            if self.img_discriminator is not None:
                with timeit('d_img forward for g', self.args.timing):
                    if self.args.condition_d_img:
                        g_scores_fake_img = self.img_discriminator(
                            imgs_pred, layout)
                    else:
                        g_scores_fake_img = self.img_discriminator(imgs_pred)

            if self.bg_discriminator is not None:
                with timeit('d_bg forward for g', self.args.timing):
                    if self.args.condition_d_bg:
                        g_scores_fake_bg = self.bg_discriminator(
                            imgs_pred, bg_layout)
                    else:
                        g_scores_fake_bg = self.bg_discriminator(imgs_pred *
                                                                 bg_layout)

        d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = None, None, None, None
        d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = None, None, None, None
        d_obj_gp = None
        d_scores_fake_img = None
        d_scores_real_img = None
        d_img_gp = None
        d_scores_fake_bg = None
        d_scores_real_bg = None
        d_bg_gp = None
        if self.forward_D:
            # forward discriminators to train discriminators
            if self.obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', self.args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        self.obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        self.obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            self.obj_discriminator.discriminator)

            if self.img_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_img forward for d', self.args.timing):
                    if self.args.condition_d_img:
                        d_scores_fake_img = self.img_discriminator(
                            imgs_fake, layout)
                        d_scores_real_img = self.img_discriminator(
                            imgs, layout)
                    else:
                        d_scores_fake_img = self.img_discriminator(imgs_fake)
                        d_scores_real_img = self.img_discriminator(imgs)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_img:
                            d_img_gp = gradient_penalty(
                                torch.cat([imgs, layout], dim=1),
                                torch.cat([imgs_fake, layout], dim=1),
                                self.img_discriminator)
                        else:
                            d_img_gp = gradient_penalty(
                                imgs, imgs_fake, self.img_discriminator)

            if self.bg_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_bg forward for d', self.args.timing):
                    if self.args.condition_d_bg:
                        d_scores_fake_bg = self.bg_discriminator(
                            imgs_fake, bg_layout)
                        d_scores_real_bg = self.bg_discriminator(
                            imgs, bg_layout)
                    else:
                        d_scores_fake_bg = self.bg_discriminator(imgs_fake *
                                                                 bg_layout)
                        d_scores_real_bg = self.bg_discriminator(imgs *
                                                                 bg_layout)

                    if self.args.gan_loss_type == "wgan-gp" and self.training:
                        if self.args.condition_d_bg:
                            d_bg_gp = gradient_penalty(
                                torch.cat([imgs, bg_layout], dim=1),
                                torch.cat([imgs_fake, bg_layout], dim=1),
                                self.bg_discriminator)
                        else:
                            d_bg_gp = gradient_penalty(imgs * bg_layout,
                                                       imgs_fake * bg_layout,
                                                       self.bg_discriminator)
        return Result(
            imgs=imgs,
            imgs_pred=imgs_pred,
            objs=objs,
            obj_fmaps=obj_fmaps,
            boxes=boxes,
            obj_to_img=obj_to_img + img_offset,
            g_scores_fake_crop=g_scores_fake_crop,
            g_obj_scores_fake_crop=g_obj_scores_fake_crop,
            g_scores_fake_img=g_scores_fake_img,
            d_scores_fake_crop=d_scores_fake_crop,
            d_obj_scores_fake_crop=d_obj_scores_fake_crop,
            d_scores_real_crop=d_scores_real_crop,
            d_obj_scores_real_crop=d_obj_scores_real_crop,
            d_scores_fake_img=d_scores_fake_img,
            d_scores_real_img=d_scores_real_img,
            d_obj_gp=d_obj_gp,
            d_img_gp=d_img_gp,
            fake_crops=fake_crops,
            real_crops=real_crops,
            mask_noise_indexes=(mask_noise_indexes + img_offset)
            if mask_noise_indexes is not None else None,
            g_rec_feature_fake_crop=g_rec_feature_fake_crop,
            d_rec_feature_fake_crop=d_rec_feature_fake_crop,
            d_rec_feature_real_crop=d_rec_feature_real_crop,
            g_scores_fake_bg=g_scores_fake_bg,
            d_scores_fake_bg=d_scores_fake_bg,
            d_scores_real_bg=d_scores_real_bg,
            d_bg_gp=d_bg_gp,
            bg_layout=bg_layout,
        )
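
The training forwards above wrap each stage in timeit(name, self.args.timing), a context manager whose definition is outside this listing. A minimal sketch of such a helper, assuming the second argument simply toggles the timing printout (an illustration, not the repository's implementation):

import time
from contextlib import contextmanager


@contextmanager
def timeit(name, enabled=True):
    # Print how long the wrapped block took when `enabled` is truthy.
    start = time.time()
    try:
        yield
    finally:
        if enabled:
            print('{}: {:.3f}s'.format(name, time.time() - start))
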
Example #12
0
    def forward(self,
                objs,
                triples,
                obj_to_img=None,
                pred_to_img=None,
                boxes_gt=None,
                masks_gt=None,
                lstm_hidden=None):
        """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i
      means that objects[o] is an object in image i. If not given then
      all objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing
      the spatial layout; if not given then use predicted boxes.
    - lstm_hidden: Tensor of shape (N, self.lstm_hid_dim)
    """
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        obj_vecs = self.obj_embeddings(objs)
        obj_vecs_orig = obj_vecs
        pred_vecs = self.pred_embeddings(p)

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        # Bounding boxes should be conditioned on context
        # because layout is finalized at this step
        context = None
        if self.context_network is not None:
            context, embedding = self.context_network(pred_vecs, pred_to_img)
            # Concatenate global context to each object depending on which image it is from
            # Probably not an efficient way to do this
            obj_with_context = torch.stack([
                torch.cat((obj_vecs[i], embedding[obj_to_img[i].item()]))
                for i in range(O)
            ])
            boxes_pred = self.box_net(obj_with_context)

            masks_pred = None
            if self.mask_net is not None:
                mask_scores = self.mask_net(obj_with_context.view(O, -1, 1, 1))
                masks_pred = mask_scores.squeeze(1).sigmoid()

        else:
            boxes_pred = self.box_net(obj_vecs)

            masks_pred = None
            if self.mask_net is not None:
                mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
                masks_pred = mask_scores.squeeze(1).sigmoid()

        s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
        s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
        rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
        rel_scores = self.rel_aux_net(rel_aux_input)

        H, W = self.image_size
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

        if masks_pred is None:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
        else:
            layout_masks = masks_pred if masks_gt is None else masks_gt
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)

        if lstm_hidden is not None:
            #print(lstm_hidden.size()[1],self.lstm_hid_dim)
            assert lstm_hidden.size()[0] == layout.size()[0]
            assert lstm_hidden.size()[1] == self.lstm_hid_dim
            lstm_embedding_vec = self.lstm_embedding(lstm_hidden)
            layout = torch.cat([layout, lstm_embedding_vec], dim=1)
        elif self.layout_noise_dim > 0:  # if not using lstm embedding
            N, C, H, W = layout.size()
            noise_shape = (N, self.layout_noise_dim, H, W)
            layout_noise = torch.randn(noise_shape,
                                       dtype=layout.dtype,
                                       device=layout.device)
            layout = torch.cat([layout, layout_noise], dim=1)
        img = self.refinement_net(layout)
        return img, boxes_pred, masks_pred, rel_scores, context
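
Example #12 concatenates the embedded LSTM hidden state with the layout along the channel dimension, which only works if the per-image embedding is broadcast over the layout's spatial grid. The sketch below shows that broadcasting step under the assumption that self.lstm_embedding returns a flat (N, E) tensor; the helper name is made up and the real module may already handle the reshape internally.

import torch


def append_hidden_to_layout(layout, hidden_embedding):
    # layout: (N, C, H, W); hidden_embedding: (N, E).
    # Tile the per-image embedding over the H x W grid and concatenate it
    # as extra channels, producing (N, C + E, H, W).
    N, _, H, W = layout.shape
    tiled = hidden_embedding.view(N, -1, 1, 1).expand(N, hidden_embedding.size(1), H, W)
    return torch.cat([layout, tiled], dim=1)
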