def forward(self, objs, layout_boxes, layout_masks, test_mode=False):
    obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']

    # Compose a per-image segmentation layout from boxes (and masks, if given).
    seg_batches = []
    for b in range(obj_vecs.size(0)):
        mask = remove_dummy_objects(objs[b], self.opt.vocab)
        objs_vecs_batch = obj_vecs[b][mask]
        layout_boxes_batch = layout_boxes[b][mask]

        if layout_masks is not None:
            # Masks layout
            layout_masks_batch = layout_masks[b][mask]
            seg = masks_to_layout(objs_vecs_batch, layout_boxes_batch,
                                  layout_masks_batch,
                                  self.opt.image_size[0],
                                  self.opt.image_size[0],
                                  test_mode=test_mode)
        else:
            # Boxes layout
            seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                  self.opt.image_size[0],
                                  self.opt.image_size[0])
        seg_batches.append(seg)
    seg = torch.cat(seg_batches, dim=0)

    # Downsample the segmentation map and run the first convolution.
    x = F.interpolate(seg, size=(self.sh, self.sw))
    x = self.fc(x)

    x = self.head_0(x, seg)
    x = self.up(x)
    x = self.G_middle_0(x, seg)

    if self.opt.num_upsampling_layers in ('more', 'most'):
        x = self.up(x)
    x = self.G_middle_1(x, seg)

    x = self.up(x)
    x = self.up_0(x, seg)
    x = self.up(x)
    x = self.up_1(x, seg)
    x = self.up(x)
    x = self.up_2(x, seg)
    x = self.up(x)
    x = self.up_3(x, seg)

    if self.opt.num_upsampling_layers == 'most':
        x = self.up(x)
        x = self.up_4(x, seg)

    x = self.conv_img(F.leaky_relu(x, 2e-1))
    x = torch.tanh(x)  # F.tanh is deprecated in favor of torch.tanh
    return x
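# The boxes_to_layout / masks_to_layout helpers are defined elsewhere in
# the repo. As a reference, here is a minimal sketch (not the repo's
# implementation) of what a boxes_to_layout-style composition can do for
# one image, assuming boxes are (x0, y0, x1, y1) in [0, 1]: each object
# embedding becomes a constant patch that is warped into its box with an
# inverse affine grid and summed into a single canvas.
import torch
import torch.nn.functional as F

def boxes_to_layout_sketch(obj_vecs, boxes, H, W):
    """obj_vecs: (O, D) embeddings; boxes: (O, 4); returns (1, D, H, W)."""
    O, D = obj_vecs.shape
    patches = obj_vecs.view(O, D, 1, 1).expand(O, D, 8, 8)
    x0, y0, x1, y1 = boxes.unbind(1)
    ww, hh = (x1 - x0).clamp(min=1e-4), (y1 - y0).clamp(min=1e-4)
    # Inverse transform: maps output-image coords back into patch coords,
    # so grid_sample places each patch inside its box (zeros elsewhere).
    theta = torch.zeros(O, 2, 3, dtype=obj_vecs.dtype, device=obj_vecs.device)
    theta[:, 0, 0], theta[:, 1, 1] = 1.0 / ww, 1.0 / hh
    theta[:, 0, 2] = -(x0 + x1 - 1) / ww
    theta[:, 1, 2] = -(y0 + y1 - 1) / hh
    grid = F.affine_grid(theta, (O, D, H, W), align_corners=False)
    warped = F.grid_sample(patches, grid, padding_mode='zeros',
                           align_corners=False)
    return warped.sum(dim=0, keepdim=True)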
def forward(self, img, objs, layout_boxes, layout_masks=None,
            gt_train=True, fool=False):
    obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']

    # Compose the segmentation layout, per image in the batch.
    seg_batches = []
    for b in range(obj_vecs.size(0)):
        mask = remove_dummy_objects(objs[b], self.opt.vocab)
        objs_vecs_batch = obj_vecs[b][mask]
        layout_boxes_batch = layout_boxes[b][mask]

        if layout_masks is not None:
            # Masks layout
            layout_masks_batch = layout_masks[b][mask]
            seg = masks_to_layout(objs_vecs_batch, layout_boxes_batch,
                                  layout_masks_batch,
                                  self.opt.image_size[0],
                                  self.opt.image_size[0],
                                  test_mode=False)  # always False in the discriminator
        else:
            # Boxes layout
            seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                  self.opt.image_size[0],
                                  self.opt.image_size[0])
        seg_batches.append(seg)
    seg = torch.cat(seg_batches, dim=0)

    # Each discriminator scale sees the image concatenated with the
    # layout, downsampled once per scale.
    disc_input = torch.cat([img, seg], dim=1)  # renamed from `input` to avoid shadowing the builtin
    result = []
    get_intermediate_features = not self.opt.no_ganFeat_loss
    for name, D in self.named_children():
        if name.startswith('discriminator'):
            out = D(disc_input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            disc_input = self.downsample(disc_input)
    return result
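# Hypothetical usage of the multi-scale outputs above: a pix2pixHD-style
# feature-matching loss. `results_fake` / `results_real` are the lists
# returned by two forward calls (one list of intermediate features per
# discriminator scale, final logits last); the names are illustrative,
# not repo API.
import torch.nn.functional as F

def feature_matching_loss(results_fake, results_real):
    loss = 0.0
    for fake_feats, real_feats in zip(results_fake, results_real):
        for f, r in zip(fake_feats[:-1], real_feats[:-1]):  # skip final logits
            loss = loss + F.l1_loss(f, r.detach())
    return loss / max(len(results_fake), 1)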
def forward(self, objs, triples, obj_to_img=None, boxes_gt=None,
            masks_gt=None):
    """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i. If not given then all
      objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
      computing the spatial layout; if not given then use predicted boxes.
    """
    O, T = objs.size(0), triples.size(0)
    s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
    s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
    edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

    if obj_to_img is None:
        obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

    # 'objs' are indices into model.vocab['object_idx_to_name']
    obj_vecs = self.obj_embeddings(objs)
    obj_vecs_orig = obj_vecs
    # 'p' are indices into model.vocab['pred_idx_to_name']
    pred_vecs = self.pred_embeddings(p)

    if isinstance(self.gconv, nn.Linear):
        obj_vecs = self.gconv(obj_vecs)
    else:
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
    if self.gconv_net is not None:
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

    # Bounding box prediction
    boxes_pred_info = None
    if self.use_bbox_info:
        # Bounding box prediction + predicted box info
        boxes_pred_info = self.box_net(obj_vecs)
        boxes_pred = boxes_pred_info[:, 0:4]  # first 4 entries are bbox coords
    else:
        boxes_pred = self.box_net(obj_vecs)

    masks_pred = None
    layout_masks = None
    if self.mask_net is not None:
        mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
        # This only affects training if the mask loss is non-zero.
        masks_pred = mask_scores.squeeze(1).sigmoid()

    s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
    s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
    s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]

    # Relationship prediction: uses predicted subject/object boxes and the
    # post-GCNN subject/object embeddings. (Alternative: the original
    # input embeddings, i.e. s_vecs / o_vecs.)
    rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs_pred, o_vecs_pred],
                              dim=1)
    rel_scores = self.rel_aux_net(rel_aux_input)

    # Concatenate triplet vectors
    triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

    # Triplet bounding boxes (8-point: subject box + object box)
    triplet_boxes_pred = None
    if self.triplet_box_net is not None:
        triplet_boxes_pred = self.triplet_box_net(triplet_input)

    # Triplet binary masks
    triplet_masks_pred = None
    if self.triplet_mask_net is not None:
        # Input dimension must be [T, C, 1, 1]
        triplet_mask_scores = self.triplet_mask_net(
            triplet_input[:, :, None, None])
        # Only used for the binary-mask CE loss, so keep raw scores
        # (no sigmoid).
        triplet_masks_pred = triplet_mask_scores.squeeze(1)

    # Triplet embedding
    triplet_embed = None
    if self.triplet_embed_net is not None:
        triplet_embed = self.triplet_embed_net(triplet_input)

    # Triplet superbox
    triplet_superboxes_pred = None
    if self.triplet_superbox_net is not None:
        triplet_superboxes_pred = self.triplet_superbox_net(triplet_input)

    H, W = self.image_size
    layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

    # Compose the layout from boxes (and masks, if predicted).
    if masks_pred is None:
        layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                 obj_to_img, H, W)
    layout_crn = layout

    sg_context_pred = None
    sg_context_pred_d = None
    if self.sg_context_net is not None:
        N, C, H, W = layout.size()
        context = sg_context_to_layout(obj_vecs, obj_to_img,
                                       pooling=self.gcnn_pooling)
        sg_context_pred_sqz = self.sg_context_net(context)

        # Vector to spatial replication. Use a fresh name for the context
        # dim so the subject indices `s` (used again below) are not
        # clobbered, which the original code did.
        b = N
        s_dim = self.sg_context_dim
        sg_context_pred = sg_context_pred_sqz.view(b, s_dim, 1, 1).expand(
            b, s_dim, layout.size(2), layout.size(3))
        layout_crn = torch.cat([layout, sg_context_pred], dim=1)

        # The discriminator uses a different FC layer than the generator.
        sg_context_predd_sqz = self.sg_context_net_d(context)
        s_dim = self.sg_context_dim_d
        sg_context_pred_d = sg_context_predd_sqz.view(b, s_dim, 1, 1).expand(
            b, s_dim, layout.size(2), layout.size(3))

    if self.layout_noise_dim > 0:
        N, C, H, W = layout.size()
        noise_shape = (N, self.layout_noise_dim, H, W)
        layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                   device=layout.device)
        layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

    # Layout model only: image refinement is disabled here.
    # img = self.refinement_net(layout_crn)
    img = None

    # Compose ground-truth triplet boxes from the subject/object indices.
    if boxes_gt is not None:
        s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
        triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
    else:
        triplet_boxes_gt = None

    return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
            layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
            rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
            triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
            triplet_superboxes_pred)
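# sg_context_to_layout is not defined in this file. A minimal sketch of
# the pooling it is assumed to perform (average the object vectors
# belonging to each image), mirroring the scatter_add pattern used by
# the context-vector code later in this file:
import torch

def sg_context_to_layout_sketch(obj_vecs, obj_to_img, pooling='avg'):
    num_imgs = int(obj_to_img.max().item()) + 1
    D = obj_vecs.size(1)
    context = torch.zeros(num_imgs, D, dtype=obj_vecs.dtype,
                          device=obj_vecs.device)
    idx = obj_to_img.view(-1, 1).expand(-1, D)
    context = context.scatter_add(0, idx, obj_vecs)  # per-image sums
    if pooling == 'avg':
        ones = torch.ones(obj_to_img.size(0), dtype=obj_vecs.dtype,
                          device=obj_vecs.device)
        counts = torch.zeros(num_imgs, dtype=obj_vecs.dtype,
                             device=obj_vecs.device).scatter_add(
                                 0, obj_to_img, ones)
        context = context / counts.clamp(min=1).view(-1, 1)
    return context  # (num_images, D)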
def forward(self, obj_to_img, boxes_gt, obj_fmaps, mask_noise_indexes=None,
            masks_gt=None, bg_layout=None):
    """
    Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i.
    - boxes_gt: FloatTensor of shape (O, 4) giving ground-truth boxes in
      [0, 1], used to compute the spatial layout.
    - obj_fmaps: per-object features used in place of graph-conv
      object embeddings.
    - mask_noise_indexes: optional indices of images whose noise should
      be zeroed during training.
    - masks_gt: optional ground-truth masks; if given they override the
      predicted masks when composing the layout.
    - bg_layout: background layout, used when noise is restricted to the
      foreground.
    """
    assert boxes_gt.max() < 1.1 and boxes_gt.min() > -0.1, \
        "boxes_gt should be within range [0,1]"

    # The graph-convolution front end of the original model (embeddings,
    # gconv, gconv_net, box/relationship aux nets) is bypassed here:
    # object features arrive precomputed in obj_fmaps.
    if self.args.not_decrease_feature_dimension:
        obj_vecs = obj_fmaps
    else:
        obj_vecs = self.obj_fmap_net(obj_fmaps)
    no_noise_obj_vecs = obj_vecs

    if self.args.object_noise_dim > 0:
        # Select objects belonging to images in mask_noise_indexes.
        if mask_noise_indexes is not None and self.training:
            mask_noise_obj_index_list = []
            for ind in mask_noise_indexes:
                mask_noise_obj_index_list.append(
                    (obj_to_img == ind).nonzero())
            mask_noise_obj_indexes = torch.cat(mask_noise_obj_index_list,
                                               dim=0)[:, 0]
        if self.args.noise_apply_method == "concat":
            object_noise = torch.randn(
                (obj_vecs.shape[0], self.args.object_noise_dim),
                dtype=obj_vecs.dtype, device=obj_vecs.device)
            if mask_noise_indexes is not None and self.training:
                object_noise[mask_noise_obj_indexes] = 0
            obj_vecs = torch.cat([obj_vecs, object_noise], dim=1)
        elif self.args.noise_apply_method == "add":
            object_noise = torch.randn(obj_vecs.shape, dtype=obj_vecs.dtype,
                                       device=obj_vecs.device)
            if mask_noise_indexes is not None and self.training:
                object_noise[mask_noise_obj_indexes] = 0
            obj_vecs = obj_vecs + object_noise

    masks_pred = None
    if self.mask_net is not None:
        mask_scores = self.mask_net(
            obj_vecs.view(obj_vecs.shape[0], -1, 1, 1))
        masks_pred = mask_scores.squeeze(1).sigmoid()

    H, W = self.image_size
    layout_boxes = boxes_gt  # this model always uses ground-truth boxes

    if masks_pred is None:
        if self.args.object_no_noise_with_bbox:
            layout = boxes_to_layout(no_noise_obj_vecs, layout_boxes,
                                     obj_to_img, H, W)
        else:
            layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img,
                                     H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        if self.args.object_no_noise_with_mask:
            layout = masks_to_layout(no_noise_obj_vecs, layout_boxes,
                                     layout_masks, obj_to_img, H, W)
        else:
            layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                     obj_to_img, H, W)
    ret_layout = layout

    if self.layout_noise_dim > 0:
        N, C, H, W = layout.size()
        if self.args.noise_apply_method == "concat":
            noise_shape = (N, self.layout_noise_dim, H, W)
        elif self.args.noise_apply_method == "add":
            noise_shape = layout.shape
        noise_std = torch.zeros(noise_shape, dtype=layout.dtype,
                                device=layout.device).fill_(
                                    self.args.noise_std)
        layout_noise = torch.normal(mean=0.0, std=noise_std)
        if self.args.layout_noise_only_on_foreground:
            # Zero the noise wherever the background layout is active.
            # NOTE: the repeat below assumes the "concat" noise shape.
            layout_noise *= (1 - bg_layout[:, :1, :, :].repeat(
                1, self.layout_noise_dim, 1, 1))
        if mask_noise_indexes is not None and self.training:
            layout_noise[mask_noise_indexes] = 0.
        if self.args.noise_apply_method == "concat":
            layout = torch.cat([layout, layout_noise], dim=1)
        elif self.args.noise_apply_method == "add":
            layout = layout + layout_noise

    img = self.refinement_net(layout)
    return img, ret_layout
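# Toy illustration (assumed shapes, not repo code) of the two
# noise_apply_method branches used above:
import torch

layout = torch.randn(2, 64, 32, 32)           # (N, C, H, W) feature layout
layout_noise_dim, noise_std = 8, 0.5

# "concat": extra noise channels are appended to the layout.
noise = torch.normal(0.0, noise_std * torch.ones(2, layout_noise_dim, 32, 32))
concat_layout = torch.cat([layout, noise], dim=1)   # (2, 72, 32, 32)

# "add": noise with the layout's own shape is added element-wise.
add_layout = layout + torch.normal(0.0, noise_std * torch.ones_like(layout))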
def forward(self, objs, triples, obj_to_img=None, pred_to_img=None,
            boxes_gt=None, masks_gt=None):
    """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i. If not given then all
      objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
      computing the spatial layout; if not given then use predicted boxes.
    """
    O, T = objs.size(0), triples.size(0)
    s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
    s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
    edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

    if obj_to_img is None:
        obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

    obj_vecs, pred_vecs = self.embedding(objs, p)
    obj_vecs_orig = obj_vecs
    obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

    boxes_pred = self.box_net(obj_vecs)

    masks_pred = None
    if self.mask_net is not None:
        mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
        masks_pred = mask_scores.squeeze(1).sigmoid()

    s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
    s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
    rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
    rel_scores = self.rel_aux_net(rel_aux_input)

    H, W = self.image_size
    layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

    if masks_pred is None:
        layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                 obj_to_img, H, W)

    # Add context embedding
    # context = self.context_network(pred_vecs)
    # TODO how to concatenate this?
    # if self.layout_noise_dim > 0:
    #     N, C, H, W = layout.size()
    #     # Concatenate noise with the new context embedding and reshape.
    #     noise = torch.randn(N, self.layout_noise_dim)
    #     noise = noise.view(noise.size(0), self.layout_noise_dim)
    #     z = torch.cat([noise, proj_c], 1)
    #     layout_noise = self.noise_layout(z)
    #     layout = torch.cat([layout, layout_noise], dim=1)

    if self.layout_noise_dim > 0:
        N, C, H, W = layout.size()
        noise_shape = (N, self.layout_noise_dim, H, W)
        layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                   device=layout.device)
        layout = torch.cat([layout, layout_noise], dim=1)

    img = self.refinement_net(layout)
    return img, boxes_pred, masks_pred, rel_scores
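# Companion sketch to the boxes variant given earlier: a masks_to_layout-
# style composition (assumed semantics, not the repo implementation).
# Each object's constant patch is modulated by its soft mask before
# warping, and objects are summed into their own image's canvas via
# obj_to_img.
import torch
import torch.nn.functional as F

def masks_to_layout_sketch(obj_vecs, boxes, masks, obj_to_img, H, W):
    O, D = obj_vecs.shape
    M = masks.size(1)                           # mask resolution, e.g. 16
    patches = obj_vecs.view(O, D, 1, 1) * masks.view(O, 1, M, M)
    x0, y0, x1, y1 = boxes.unbind(1)
    ww, hh = (x1 - x0).clamp(min=1e-4), (y1 - y0).clamp(min=1e-4)
    theta = torch.zeros(O, 2, 3, dtype=obj_vecs.dtype, device=obj_vecs.device)
    theta[:, 0, 0], theta[:, 1, 1] = 1.0 / ww, 1.0 / hh
    theta[:, 0, 2] = -(x0 + x1 - 1) / ww
    theta[:, 1, 2] = -(y0 + y1 - 1) / hh
    grid = F.affine_grid(theta, (O, D, H, W), align_corners=False)
    warped = F.grid_sample(patches, grid, padding_mode='zeros',
                           align_corners=False)
    num_imgs = int(obj_to_img.max().item()) + 1
    out = torch.zeros(num_imgs, D, H, W, dtype=obj_vecs.dtype,
                      device=obj_vecs.device)
    out.index_add_(0, obj_to_img, warped)       # sum objects per image
    return out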
def forward(self, objs, triples, obj_to_img=None, boxes_gt=None,
            masks_gt=None, tr_to_img=None):
    """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i. If not given then all
      objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
      computing the spatial layout; if not given then use predicted boxes.
    - tr_to_img: LongTensor of shape (T,) mapping each triple to its image.
    """
    O, T = objs.size(0), triples.size(0)
    s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
    s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
    edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

    if obj_to_img is None:
        obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

    # 'objs' are indices into model.vocab['object_idx_to_name']
    obj_vecs = self.obj_embeddings(objs)
    obj_vecs_orig = obj_vecs
    # 'p' are indices into model.vocab['pred_idx_to_name']
    pred_vecs = self.pred_embeddings(p)
    pred_vecs_orig = pred_vecs

    if isinstance(self.gconv, nn.Linear):
        obj_vecs = self.gconv(obj_vecs)
    else:
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
    if self.gconv_net is not None:
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

    # ---- Object context vectors: per-image average of object vectors ----
    # (Assumes obj_to_img is sorted, so the last entry gives the image count.)
    num_imgs = obj_to_img[obj_to_img.size(0) - 1] + 1
    context_obj_vecs = torch.zeros(num_imgs, obj_vecs.size(1),
                                   dtype=obj_vecs.dtype,
                                   device=obj_vecs.device)
    obj_to_img_exp = obj_to_img.view(-1, 1).expand_as(obj_vecs)
    context_obj_vecs = context_obj_vecs.scatter_add(0, obj_to_img_exp,
                                                    obj_vecs)
    # Object counts per image
    obj_counts = torch.zeros(num_imgs, dtype=obj_vecs.dtype,
                             device=obj_vecs.device)
    ones = torch.ones(obj_to_img.size(0), dtype=obj_vecs.dtype,
                      device=obj_vecs.device)
    obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
    context_obj_vecs = context_obj_vecs / obj_counts.view(-1, 1)
    context_obj_vecs = context_obj_vecs[obj_to_img]
    context_obj_vecs = context_obj_vecs[s]

    # (A scatter_add-based triplet context over tr_to_img was prototyped
    # here; the triplet context is now computed by triplet_context_net
    # below.)

    # ---- Mask out some predicates (masked scene-graph objective) ----
    pred_mask_gt = None
    pred_mask_scores = None
    if self.use_masked_sg:
        perc = torch.FloatTensor([0.50])  # hyperparameter: fraction to mask
        num_mask_objs = int(torch.floor(perc * len(s)).item())
        num_mask_objs = max(num_mask_objs, 1)
        mask_idx = torch.randint(0, len(s) - 1, (num_mask_objs, ))

        # Ground-truth predicates at the masked positions
        pred_mask_gt = p[mask_idx.long()]

        # Replace the masked positions with the mask embedding, giving a
        # partially masked scene graph. Clone so the original predicate
        # embeddings are not modified in place (the original code aliased
        # pred_vecs_orig here, which corrupted p_vecs below).
        pred_vecs_copy = pred_vecs_orig.clone()
        # TODO: add a dedicated None embedding (e.g. index 46).
        pred_vecs_copy[mask_idx.long()] = self.pred_embeddings(
            torch.tensor([self.mask_pred]).cuda())

        # Convolve the masked scene graph.
        if isinstance(self.gconv, nn.Linear):
            mask_obj_vecs = self.gconv(obj_vecs_orig)
        else:
            mask_obj_vecs, mask_pred_vecs = self.gconv(
                obj_vecs_orig, pred_vecs_copy, edges)
        if self.gconv_net is not None:
            mask_obj_vecs, mask_pred_vecs = self.gconv_net(
                mask_obj_vecs, mask_pred_vecs, edges)

        # Subject/object vectors at the masked positions
        s_mask = s[mask_idx.long()]
        o_mask = o[mask_idx.long()]
        subj_vecs_mask = mask_obj_vecs[s_mask]
        obj_vecs_mask = mask_obj_vecs[o_mask]

        # Predict the masked predicate relationship
        pred_mask_input = torch.cat([subj_vecs_mask, obj_vecs_mask], dim=1)
        pred_mask_scores = self.pred_mask_net(pred_mask_input)

    # Bounding box prediction
    boxes_pred_info = None
    if self.use_bbox_info:
        # Bounding box prediction + predicted box info
        boxes_pred_info = self.box_net(obj_vecs)
        boxes_pred = boxes_pred_info[:, 0:4]  # first 4 entries are bbox coords
    else:
        boxes_pred = self.box_net(obj_vecs)

    masks_pred = None
    layout_masks = None
    if self.mask_net is not None:
        mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
        masks_pred = mask_scores.squeeze(1).sigmoid()

    # Predicted bboxes and post-GCNN embedding vectors
    s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
    s_vecs_pred, o_vecs_pred = obj_vecs[s], obj_vecs[o]
    # Input (pre-GCNN) embedding vectors
    s_vecs, o_vecs, p_vecs = (obj_vecs_orig[s], obj_vecs_orig[o],
                              pred_vecs_orig)
    input_tr_vecs = torch.cat([s_vecs, p_vecs, o_vecs], dim=1)

    # VSA-style binding (obj/pred vectors of varying kinds were tried,
    # e.g. the input or post-GCNN embeddings).
    fr_obj_vecs = self.fr_obj_embeddings(objs)
    fr_pred_vecs = self.fr_pred_embeddings(p)
    fr_s_vecs, fr_o_vecs = fr_obj_vecs[s], fr_obj_vecs[o]
    mapc_bind = fr_s_vecs * fr_o_vecs * fr_pred_vecs
    mapc_bind = F.normalize(mapc_bind, p=2, dim=1)

    # Relationship prediction: predicted subject/object boxes plus the
    # original (pre-GCNN) subject/object embeddings.
    rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
    rel_scores = self.rel_aux_net(rel_aux_input)

    # Subject prediction
    subj_aux_input = torch.cat([s_boxes, o_boxes, p_vecs, o_vecs], dim=1)
    subj_scores = self.subj_aux_net(subj_aux_input)

    # Object prediction
    obj_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, p_vecs], dim=1)
    obj_scores = self.obj_aux_net(obj_aux_input)

    # Object class prediction (for output object vectors)
    obj_class_scores = self.obj_class_aux_net(obj_vecs)

    # Relationship embedding augmentation via random masking (optional)
    use_augmentation = False
    mask_rel_embedding = None
    if use_augmentation:
        num_augs = 4
        num_preds = len(p)
        mask_rel_embedding = torch.zeros(
            (num_preds, num_augs + 1, self.embedding_dim),
            dtype=obj_vecs.dtype, device=obj_vecs.device)
        rel_embedding = []
        # Mask a fraction perc_aug of each predicate vector's entries.
        perc_aug = torch.FloatTensor([0.4])  # hyperparameter
        num_mask = int(torch.floor(perc_aug * len(pred_vecs[0])).item())
        for i in range(num_preds):
            pred = pred_vecs[i]
            vecs = [pred]
            for j in range(num_augs):
                # Pick a random set of indices to zero out.
                rand_idx = torch.randperm(len(pred_vecs[0]))
                mask_idx = rand_idx[:num_mask]
                pred_mask = pred.detach().clone()
                pred_mask[mask_idx] = 0.0
                vecs += [pred_mask]
            # Project the masked, augmented relationship vectors.
            pred_mask = self.rel_embed_aux_net(torch.stack(vecs))
            pred_mask = F.normalize(pred_mask, dim=1)
            mask_rel_embedding[i, :, :] = pred_mask
            rel_embedding += [pred_mask[0]]

    # Projection head for the supervised contrastive loss
    rel_embedding = self.rel_embed_aux_net(pred_vecs)
    rel_embedding = F.normalize(rel_embedding, dim=1)

    # Relationship class prediction on predicates
    rel_class_scores = self.rel_class_aux_net(pred_vecs)

    # Concatenate triplet vectors
    triplet_input = torch.cat([s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)

    # Triplet bounding boxes (8-point: subject box + object box)
    triplet_boxes_pred = None
    if self.triplet_box_net is not None:
        triplet_boxes_pred = self.triplet_box_net(triplet_input)

    # Triplet binary masks
    triplet_masks_pred = None
    if self.triplet_mask_net is not None:
        # Input dimension must be [T, C, 1, 1]
        triplet_mask_scores = self.triplet_mask_net(
            triplet_input[:, :, None, None])
        # Only used for the binary-mask CE loss, so keep raw scores
        # (no sigmoid).
        triplet_masks_pred = triplet_mask_scores.squeeze(1)

    # Triplet embedding
    triplet_embed = None
    if self.triplet_embed_net is not None:
        triplet_embed = self.triplet_embed_net(triplet_input)

    # Triplet superbox
    triplet_superboxes_pred = None
    if self.triplet_superbox_net is not None:
        # Predict 2-point superboxes
        triplet_superboxes_pred = self.triplet_superbox_net(triplet_input)

    # Predicate grounding
    pred_ground = None
    if self.pred_ground_net is not None:
        # Predict a 2-point grounding box for each predicate
        pred_ground = self.pred_ground_net(pred_vecs)

    # Triplet context (output dimension is 384)
    triplet_context_input = torch.cat(
        [context_obj_vecs, s_vecs_pred, pred_vecs, o_vecs_pred], dim=1)
    context_tr_vecs = self.triplet_context_net(triplet_context_input)

    H, W = self.image_size
    layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

    # Compose the layout from boxes (and masks, if predicted).
    if masks_pred is None:
        layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                 obj_to_img, H, W)
    layout_crn = layout

    sg_context_pred = None
    sg_context_pred_d = None
    if self.sg_context_net is not None:
        N, C, H, W = layout.size()
        context = sg_context_to_layout(obj_vecs, obj_to_img,
                                       pooling=self.gcnn_pooling)
        sg_context_pred_sqz = self.sg_context_net(context)

        # Vector to spatial replication. Use a fresh name for the context
        # dim so the subject indices `s` (used again below) are not
        # clobbered, which the original code did.
        b = N
        s_dim = self.sg_context_dim
        sg_context_pred = sg_context_pred_sqz.view(b, s_dim, 1, 1).expand(
            b, s_dim, layout.size(2), layout.size(3))
        layout_crn = torch.cat([layout, sg_context_pred], dim=1)

        # The discriminator uses a different FC layer than the generator.
        sg_context_predd_sqz = self.sg_context_net_d(context)
        s_dim = self.sg_context_dim_d
        sg_context_pred_d = sg_context_predd_sqz.view(b, s_dim, 1, 1).expand(
            b, s_dim, layout.size(2), layout.size(3))

    if self.layout_noise_dim > 0:
        N, C, H, W = layout.size()
        noise_shape = (N, self.layout_noise_dim, H, W)
        layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                   device=layout.device)
        layout_crn = torch.cat([layout_crn, layout_noise], dim=1)

    # Layout model only: image refinement is disabled here.
    # img = self.refinement_net(layout_crn)
    img = None

    # Compose ground-truth triplet boxes from the subject/object indices.
    if boxes_gt is not None:
        s_boxes_gt, o_boxes_gt = boxes_gt[s], boxes_gt[o]
        triplet_boxes_gt = torch.cat([s_boxes_gt, o_boxes_gt], dim=1)
    else:
        triplet_boxes_gt = None

    return (img, boxes_pred, masks_pred, objs, layout, layout_boxes,
            layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
            rel_scores, obj_vecs, pred_vecs, triplet_boxes_pred,
            triplet_boxes_gt, triplet_masks_pred, boxes_pred_info,
            triplet_superboxes_pred, obj_scores, pred_mask_gt,
            pred_mask_scores, context_tr_vecs, input_tr_vecs,
            obj_class_scores, rel_class_scores, subj_scores, rel_embedding,
            mask_rel_embedding, pred_ground)  # mapc_bind computed but not returned
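# Hypothetical training objective for the masked-predicate branch above,
# assuming `pred_mask_scores` has shape (num_masked, num_predicates) and
# `pred_mask_gt` holds the ground-truth predicate indices returned by
# the forward:
import torch.nn.functional as F

def masked_predicate_loss(pred_mask_scores, pred_mask_gt):
    return F.cross_entropy(pred_mask_scores, pred_mask_gt.view(-1).long())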
def forward(self, objs, triples, obj_to_img=None, boxes_gt=None,
            masks_gt=None):
    """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Main process: graph >>> graph conv >>> layout >>> CRN.

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i. If not given then all
      objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
      computing the spatial layout; if not given then use predicted boxes.
    """
    O, T = objs.size(0), triples.size(0)
    s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
    s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
    # Edges specify start (subject) and end (object) nodes.
    edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

    if obj_to_img is None:
        obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

    obj_vecs = self.obj_embeddings(objs)
    obj_vecs_orig = obj_vecs
    pred_vecs = self.pred_embeddings(p)

    # Graph convolutional network: self.gconv is the first layer,
    # self.gconv_net stacks the remaining layers.
    if isinstance(self.gconv, nn.Linear):
        obj_vecs = self.gconv(obj_vecs)
    else:
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
    if self.gconv_net is not None:
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

    boxes_pred = self.box_net(obj_vecs)

    masks_pred = None
    if self.mask_net is not None:
        mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
        masks_pred = mask_scores.squeeze(1).sigmoid()

    s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
    s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
    rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
    rel_scores = self.rel_aux_net(rel_aux_input)

    H, W = self.image_size
    layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

    # Generate the scene layout from bounding boxes (and masks, if any).
    # The layout is a soft per-pixel feature canvas, not a hard
    # segmentation mask.
    if masks_pred is None:
        layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                 obj_to_img, H, W)

    # layout_noise_dim extra channels of Gaussian noise give the CRN a
    # source of stochasticity.
    if self.layout_noise_dim > 0:
        N, C, H, W = layout.size()
        noise_shape = (N, self.layout_noise_dim, H, W)
        layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                   device=layout.device)
        layout = torch.cat([layout, layout_noise], dim=1)

    img = self.refinement_net(layout)
    return img, boxes_pred, masks_pred, rel_scores
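# Hypothetical call for the forward above (`model` and the vocab indices
# are illustrative assumptions, not repo fixtures):
import torch

objs = torch.tensor([1, 5, 2])             # O = 3 object category indices
triples = torch.tensor([[0, 3, 1],         # (objs[0], pred 3, objs[1])
                        [1, 7, 2]])        # T = 2 triples as [s, p, o]
obj_to_img = torch.tensor([0, 0, 0])       # all three objects in image 0
# img, boxes_pred, masks_pred, rel_scores = model(objs, triples, obj_to_img)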
def forward(self, objs, triples, obj_to_img=None, pred_to_img=None,
            boxes_gt=None, masks_gt=None, lstm_hidden=None):
    """
    Required Inputs:
    - objs: LongTensor of shape (O,) giving categories for all objects
    - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
      means that there is a triple (objs[s], p, objs[o])

    Optional Inputs:
    - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
      that objects[o] is an object in image i. If not given then all
      objects are assumed to belong to the same image.
    - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
      computing the spatial layout; if not given then use predicted boxes.
    - lstm_hidden: Tensor of shape (N, self.lstm_hid_dim)
    """
    O, T = objs.size(0), triples.size(0)
    s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
    s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
    edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)

    if obj_to_img is None:
        obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

    obj_vecs = self.obj_embeddings(objs)
    obj_vecs_orig = obj_vecs
    pred_vecs = self.pred_embeddings(p)

    if isinstance(self.gconv, nn.Linear):
        obj_vecs = self.gconv(obj_vecs)
    else:
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
    if self.gconv_net is not None:
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

    # Bounding boxes are conditioned on global context, because the
    # layout is finalized at this step.
    context = None
    if self.context_network is not None:
        context, embedding = self.context_network(pred_vecs, pred_to_img)
        # Concatenate the global context of each object's image to the
        # object vector. (Probably not the most efficient way to do this.)
        obj_with_context = torch.stack([
            torch.cat((obj_vecs[i], embedding[obj_to_img[i].item()]))
            for i in range(O)
        ])
        boxes_pred = self.box_net(obj_with_context)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_with_context.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()
    else:
        boxes_pred = self.box_net(obj_vecs)

        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(obj_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

    s_boxes, o_boxes = boxes_pred[s], boxes_pred[o]
    s_vecs, o_vecs = obj_vecs_orig[s], obj_vecs_orig[o]
    rel_aux_input = torch.cat([s_boxes, o_boxes, s_vecs, o_vecs], dim=1)
    rel_scores = self.rel_aux_net(rel_aux_input)

    H, W = self.image_size
    layout_boxes = boxes_pred if boxes_gt is None else boxes_gt

    if masks_pred is None:
        layout = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
    else:
        layout_masks = masks_pred if masks_gt is None else masks_gt
        layout = masks_to_layout(obj_vecs, layout_boxes, layout_masks,
                                 obj_to_img, H, W)

    if lstm_hidden is not None:
        assert lstm_hidden.size()[0] == layout.size()[0]
        assert lstm_hidden.size()[1] == self.lstm_hid_dim
        # NOTE: this assumes self.lstm_embedding returns a spatial tensor
        # of shape (N, C', H, W); a flat (N, d) output would not
        # concatenate with the layout along dim=1.
        lstm_embedding_vec = self.lstm_embedding(lstm_hidden)
        layout = torch.cat([layout, lstm_embedding_vec], dim=1)
    elif self.layout_noise_dim > 0:
        # If not using the LSTM embedding, append noise channels instead.
        N, C, H, W = layout.size()
        noise_shape = (N, self.layout_noise_dim, H, W)
        layout_noise = torch.randn(noise_shape, dtype=layout.dtype,
                                   device=layout.device)
        layout = torch.cat([layout, layout_noise], dim=1)

    img = self.refinement_net(layout)
    return img, boxes_pred, masks_pred, rel_scores, context