def forward(self, input, objects, objects_length):
    object_features = input
    context_features = self.context_feature_extract(input)
    relation_features = self.relation_feature_extract(input)

    outputs = list()
    objects_index = 0
    for i in range(input.size(0)):
        box = objects[objects_index:objects_index + objects_length[i].item()]
        objects_index += objects_length[i].item()

        with torch.no_grad():
            batch_ind = i + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)

            # generate a "full-image" bounding box
            image_h, image_w = input.size(2) * self.downsample_rate, input.size(3) * self.downsample_rate
            image_box = torch.cat([
                torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                image_w + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                image_h + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)
            ], dim=-1)

            # meshgrid to obtain the subject and object bounding boxes
            sub_id, obj_id = jactorch.meshgrid(torch.arange(box.size(0), dtype=torch.int64, device=box.device), dim=0)
            sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)
            sub_box, obj_box = jactorch.meshgrid(box, dim=0)
            sub_box = sub_box.contiguous().view(box.size(0) ** 2, 4)
            obj_box = obj_box.contiguous().view(box.size(0) ** 2, 4)

            # union box
            union_box = functional.generate_union_box(sub_box, obj_box)
            rel_batch_ind = i + torch.zeros(union_box.size(0), 1, dtype=box.dtype, device=box.device)

            # intersection maps
            box_context_imap = functional.generate_intersection_map(box, image_box, self.pool_size)
            sub_union_imap = functional.generate_intersection_map(sub_box, union_box, self.pool_size)
            obj_union_imap = functional.generate_intersection_map(obj_box, union_box, self.pool_size)

        this_context_features = self.context_roi_pool(context_features, torch.cat([batch_ind, image_box], dim=-1))
        x, y = this_context_features.chunk(2, dim=1)
        this_object_features = self.object_feature_fuse(torch.cat([
            self.object_roi_pool(object_features, torch.cat([batch_ind, box], dim=-1)),
            x, y * box_context_imap
        ], dim=1))

        this_relation_features = self.relation_roi_pool(relation_features, torch.cat([rel_batch_ind, union_box], dim=-1))
        x, y, z = this_relation_features.chunk(3, dim=1)
        this_relation_features = self.relation_feature_fuse(torch.cat([
            this_object_features[sub_id], this_object_features[obj_id],
            x, y * sub_union_imap, z * obj_union_imap
        ], dim=1))

        outputs.append([
            None,
            self._norm(self.object_feature_fc(this_object_features.view(box.size(0), -1))),
            self._norm(self.relation_feature_fc(this_relation_features.view(
                box.size(0) * box.size(0), -1)).view(box.size(0), box.size(0), -1))
        ])

    return outputs
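# Illustrative sketch (not part of the original code): the subject/object pairing and
# the union boxes used above can be reproduced with plain torch. The helper names
# pair_boxes and union_boxes are hypothetical; the exact pair ordering in the model
# follows jactorch.meshgrid, which this sketch only approximates.
import torch

def pair_boxes(box):
    # box: (n, 4) as (x1, y1, x2, y2); returns all n * n ordered (subject, object) pairs.
    n = box.size(0)
    sub_box = box.unsqueeze(1).expand(n, n, 4).reshape(n * n, 4)
    obj_box = box.unsqueeze(0).expand(n, n, 4).reshape(n * n, 4)
    return sub_box, obj_box

def union_boxes(sub_box, obj_box):
    # Smallest axis-aligned box that encloses both boxes of each pair.
    x1 = torch.min(sub_box[:, 0], obj_box[:, 0])
    y1 = torch.min(sub_box[:, 1], obj_box[:, 1])
    x2 = torch.max(sub_box[:, 2], obj_box[:, 2])
    y2 = torch.max(sub_box[:, 3], obj_box[:, 3])
    return torch.stack([x1, y1, x2, y2], dim=-1)

# boxes = torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]])
# sub, obj = pair_boxes(boxes)    # (4, 4) each
# union = union_boxes(sub, obj)   # (4, 4); e.g. the (box0, box1) pair gives [0., 0., 20., 20.]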
def gen_voronoi(centers, height, width):
    range_y = torch.arange(height, device=centers.device)
    range_x = torch.arange(width, device=centers.device)
    y, x = jactorch.meshgrid(range_y, range_x, dim=0)
    y, x = y.reshape(-1), x.reshape(-1)
    coords = torch.stack([y, x], dim=1).float()
    coords, centers = jactorch.meshgrid(coords, centers, dim=0)
    dis = (coords[:, :, 0] - centers[:, :, 1]) ** 2 + (coords[:, :, 1] - centers[:, :, 0]) ** 2
    assignment = dis.argmin(1)
    return dis.view((height, width, -1)), assignment.view((height, width))
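# Minimal usage sketch (illustrative, not from the original code). The distance
# computation above swaps columns, so centers are passed in (x, y) order while the
# grid coordinates are stacked as (y, x).
import torch

centers = torch.tensor([[3.0, 2.0], [20.0, 10.0], [8.0, 14.0]])
dis, assignment = gen_voronoi(centers, height=16, width=24)
assert dis.shape == (16, 24, 3)       # squared distance from every pixel to each center
assert assignment.shape == (16, 24)   # index of the nearest center per pixel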
def cross_similarity(self, query, identifier):
    mapping = self.get_attribute(identifier)
    query = mapping(query)
    query = query / query.norm(2, dim=-1, keepdim=True)
    q1, q2 = jactorch.meshgrid(query, dim=-2)
    return self.similarity2(q1, q2, identifier, _normalized=True)
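# Illustrative sketch (not part of the original code), assuming similarity2 reduces to
# a dot product once its inputs are unit-normalized: the pairwise step above then
# amounts to a cosine-similarity matrix over the projected queries.
import torch

q = torch.randn(5, 64)
q = q / q.norm(2, dim=-1, keepdim=True)  # unit-normalize each query, as in cross_similarity
sim = q @ q.t()                          # (5, 5) matrix of pairwise cosine similarities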
def generate_roi_pool_bins(box, bin_size, c2l=COOR_TO_LEN_CORR):
    # TODO(Jiayuan Mao @ 07/20): workaround: linspace is not implemented for CUDA.
    linspace = torch.linspace(0, 1, bin_size + 1, dtype=box.dtype).to(device=box.device)
    for i in range(box.dim() - 1):
        linspace.unsqueeze_(0)
    x_space = linspace * (__last(box, 2) - __last(box, 0) + c2l).unsqueeze(-1) + __last(box, 0).unsqueeze(-1)
    y_space = linspace * (__last(box, 3) - __last(box, 1) + c2l).unsqueeze(-1) + __last(box, 1).unsqueeze(-1)
    x1, x2 = x_space[:, :-1], x_space[:, 1:] - c2l
    y1, y2 = y_space[:, :-1], y_space[:, 1:] - c2l
    y1, x1 = jactorch.meshgrid(y1, x1, dim=-1)
    y2, x2 = jactorch.meshgrid(y2, x2, dim=-1)

    # shape: nr_boxes, bin_size^2, 4
    bins = torch.stack([x1, y1, x2, y2], dim=-1).view(box.size(0), -1, 4)
    return bins.float()
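# Minimal usage sketch (illustrative, not from the original code), assuming
# COOR_TO_LEN_CORR and the __last helper are defined in the surrounding module:
# split two boxes into a 7 x 7 grid of sub-bins.
import torch

boxes = torch.tensor([[0.0, 0.0, 13.0, 13.0], [4.0, 6.0, 18.0, 20.0]])
bins = generate_roi_pool_bins(boxes, bin_size=7)
assert bins.shape == (2, 49, 4)  # nr_boxes, bin_size ** 2, one (x1, y1, x2, y2) per bin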
def forward(self, input, image, masks=None):
    object_features = input                                           # [batch_size, feature_dim, h_f, w_f]
    context_features = self.context_feature_extract(input)            # [batch_size, feature_dim, h_f, w_f]
    relation_features = self.relation_feature_extract(input)          # [batch_size, feature_dim // 2 * 3, h_f, w_f]
    masks_monet = self.monet_mask_extract(self.image_resize(image))   # [batch_size, slot_num, h_m, w_m]

    if masks is None:
        masks = masks_monet
    else:
        masks = self.true_mask_resize(
            masks.view(input.shape[0] * self.slot_num, 1, self.h_raw_raw, self.w_raw_raw))
        masks = masks.view(input.shape[0], self.slot_num, self.h_m, self.w_m)
        self.monet_mask_extract.m = masks

    if self.loss_type == 'separate':
        masks = masks.detach()
    masks = self.mask_resize(masks.view(-1, 1, self.h_m, self.w_m)).view(
        input.shape[0], -1, self.h_f, self.w_f)

    sub_id, obj_id = jactorch.meshgrid(
        torch.arange(self.slot_num, dtype=torch.long, device=input.device), dim=0)
    sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)

    masked_object_features = object_features.unsqueeze(1) * masks.unsqueeze(2)   # [batch_size, slot_num, feature_dim, h_f, w_f]
    masked_context_features = context_features.unsqueeze(1) * masks.unsqueeze(2)
    masked_relation_features = relation_features.unsqueeze(1) * (masks[:, sub_id] + masks[:, obj_id]).unsqueeze(2)

    x_context, y_context = masked_context_features.chunk(2, dim=2)
    combined_object_features = torch.cat([
        masked_object_features, x_context, y_context * masks.unsqueeze(2)
    ], dim=2)
    combined_object_features = combined_object_features.view(-1, self.feature_dim * 2, self.h_f, self.w_f)
    combined_object_features = self.object_feature_fuse(combined_object_features)
    combined_object_features = combined_object_features.view(
        input.shape[0], self.slot_num, self.output_dims[1], self.h_f, self.w_f)

    x_relation, y_relation, z_relation = masked_relation_features.chunk(3, dim=2)
    combined_relation_features = torch.cat([
        combined_object_features[:, sub_id], combined_object_features[:, obj_id],
        x_relation,
        y_relation * masks[:, sub_id].unsqueeze(2),
        z_relation * masks[:, obj_id].unsqueeze(2)
    ], dim=2)
    combined_relation_features = combined_relation_features.view(
        -1, self.feature_dim // 2 * 3 + self.output_dims[1] * 2, self.h_f, self.w_f)
    combined_relation_features = self.relation_feature_fuse(combined_relation_features)

    combined_object_features = combined_object_features.view(masks.shape[0] * masks.shape[1], -1)
    combined_object_features = self._norm(self.object_feature_fc(combined_object_features))
    combined_object_features = combined_object_features.view(masks.shape[0], masks.shape[1], -1)

    combined_relation_features = combined_relation_features.view(masks.shape[0] * masks.shape[1] ** 2, -1)
    combined_relation_features = self._norm(self.object_feature_fc(combined_relation_features))
    combined_relation_features = combined_relation_features.view(
        masks.shape[0], masks.shape[1], masks.shape[1], -1)

    outputs = []
    for i in range(input.shape[0]):
        outputs.append([None, combined_object_features[i], combined_relation_features[i]])
    return outputs
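# Shape sketch (illustrative, not part of the original code): how the per-slot masks
# are broadcast over the feature map in masked_object_features above. All sizes here
# are made up for the example.
import torch

B, S, C, H, W = 2, 5, 16, 8, 12                   # batch, slots, channels, height, width
feat = torch.randn(B, C, H, W)                    # backbone feature map
masks = torch.rand(B, S, H, W)                    # one soft mask per slot
masked = feat.unsqueeze(1) * masks.unsqueeze(2)   # broadcast to (B, S, C, H, W)
assert masked.shape == (B, S, C, H, W)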
def forward(self, input, feed_dict, mode=0, tar_obj_id=-1):
    """
    Extract region, collision and box-sequence features for the models.
    mode: 0 for the observed video, 1 for future prediction and 2 for counterfactual prediction.
    """
    object_features = input
    context_features = self.context_feature_extract(input)
    relation_features = self.relation_feature_extract(input)
    outputs = list()
    # pdb.set_trace()

    def parse_boxes_for_frm(feed_dict, frm_idx, mode=0, tar_obj_id=-1):
        if mode == 0:
            boxes_list = []
            tube_id_list = []
            frm_id = feed_dict['tube_info']['frm_list'][frm_idx]
            for tube_id, tube_info in feed_dict['tube_info'].items():
                if not isinstance(tube_id, int):
                    continue
                assert len(tube_info['frm_name']) == len(tube_info['boxes'])
                if frm_id not in tube_info['frm_name']:
                    continue
                box_idx = tube_info['frm_name'].index(frm_id)
                box = tube_info['boxes'][box_idx].squeeze(0)
                boxes_list.append(torch.tensor(box, device=feed_dict['img'].device))
                tube_id_list.append(tube_id)
            boxes_tensor = torch.stack(boxes_list, 0).cuda()
            return boxes_tensor, tube_id_list
        elif mode == 1 or mode == 3:
            boxes_list = []
            tube_id_list = []
            frm_id = feed_dict['predictions']['frm_list'][frm_idx]
            for tube_id, tube_info in feed_dict['predictions'].items():
                if not isinstance(tube_id, int):
                    continue
                assert len(tube_info['frm_name']) == len(tube_info['boxes'])
                if frm_id not in tube_info['frm_name']:
                    continue
                box_idx = tube_info['frm_name'].index(frm_id)
                box = tube_info['boxes'][box_idx]
                boxes_list.append(torch.tensor(box, device=feed_dict['img_future'].device))
                tube_id_list.append(tube_id)
            boxes_tensor = torch.stack(boxes_list, 0).cuda()
            return boxes_tensor, tube_id_list
        elif mode == 2:
            boxes_list = []
            tube_id_list = []
            frm_id = feed_dict['counterfacts'][tar_obj_id]['frm_list'][frm_idx]
            for tube_id, tube_info in feed_dict['counterfacts'][tar_obj_id].items():
                if not isinstance(tube_id, int):
                    continue
                assert len(tube_info['frm_name']) == len(tube_info['boxes'])
                if frm_id not in tube_info['frm_name']:
                    continue
                box_idx = tube_info['frm_name'].index(frm_id)
                box = tube_info['boxes'][box_idx]
                boxes_list.append(torch.tensor(box, device=feed_dict['img_counterfacts'][tar_obj_id].device))
                tube_id_list.append(tube_id)
            boxes_tensor = torch.stack(boxes_list, 0).cuda()
            return boxes_tensor, tube_id_list

    for i in range(input.size(0)):
        boxes, tube_id_list = parse_boxes_for_frm(feed_dict, i, mode, tar_obj_id)

        with torch.no_grad():
            batch_ind = i + torch.zeros(boxes.size(0), 1, dtype=boxes.dtype, device=boxes.device)

            # generate a "full-image" bounding box
            image_h, image_w = input.size(2) * self.downsample_rate, input.size(3) * self.downsample_rate
            image_box = torch.cat([
                torch.zeros(boxes.size(0), 1, dtype=boxes.dtype, device=boxes.device),
                torch.zeros(boxes.size(0), 1, dtype=boxes.dtype, device=boxes.device),
                image_w + torch.zeros(boxes.size(0), 1, dtype=boxes.dtype, device=boxes.device),
                image_h + torch.zeros(boxes.size(0), 1, dtype=boxes.dtype, device=boxes.device)
            ], dim=-1)

            # meshgrid to obtain the subject and object bounding boxes
            sub_id, obj_id = jactorch.meshgrid(
                torch.arange(boxes.size(0), dtype=torch.int64, device=boxes.device), dim=0)
            sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)
            sub_box, obj_box = jactorch.meshgrid(boxes, dim=0)
            sub_box = sub_box.contiguous().view(boxes.size(0) ** 2, 4)
            obj_box = obj_box.contiguous().view(boxes.size(0) ** 2, 4)

            # union box
            union_box = functional.generate_union_box(sub_box, obj_box)
            rel_batch_ind = i + torch.zeros(union_box.size(0), 1, dtype=boxes.dtype, device=boxes.device)

            # intersection maps
            box_context_imap = functional.generate_intersection_map(boxes, image_box, self.pool_size)
            sub_union_imap = functional.generate_intersection_map(sub_box, union_box, self.pool_size)
            obj_union_imap = functional.generate_intersection_map(obj_box, union_box, self.pool_size)

        this_context_features = self.context_roi_pool(context_features, torch.cat([batch_ind, image_box], dim=-1))
        x, y = this_context_features.chunk(2, dim=1)
        this_object_features = self.object_feature_fuse(torch.cat([
            self.object_roi_pool(object_features, torch.cat([batch_ind, boxes], dim=-1)),
            x, y * box_context_imap
        ], dim=1))

        this_relation_features = self.relation_roi_pool(relation_features, torch.cat([rel_batch_ind, union_box], dim=-1))
        x, y, z = this_relation_features.chunk(3, dim=1)
        this_relation_features = self.relation_feature_fuse(torch.cat([
            this_object_features[sub_id], this_object_features[obj_id],
            x, y * sub_union_imap, z * obj_union_imap
        ], dim=1))

        if DEBUG:
            outputs.append([None, this_object_features, this_relation_features])
        else:
            outputs.append([
                None,
                self._norm(self.object_feature_fc(this_object_features.view(boxes.size(0), -1))),
                self._norm(self.relation_feature_fc(this_relation_features.view(
                    boxes.size(0) * boxes.size(0), -1)).view(boxes.size(0), boxes.size(0), -1)),
                tube_id_list
            ])

    outputs_new = self.merge_tube_obj_ftr(outputs, feed_dict, mode, tar_obj_id)
    return outputs_new
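# Illustrative sketch (not from the original code): the *_roi_pool calls above consume
# rows of the form (batch_index, x1, y1, x2, y2). This is how such rows are assembled
# for image i of the batch; the box values are made up for the example.
import torch

i = 3
boxes = torch.tensor([[10.0, 12.0, 50.0, 60.0], [20.0, 5.0, 44.0, 31.0]])
batch_ind = i + torch.zeros(boxes.size(0), 1, dtype=boxes.dtype)  # column of batch indices
rois = torch.cat([batch_ind, boxes], dim=-1)                      # (nr_boxes, 5)
assert rois.shape == (2, 5) and rois[0, 0].item() == 3.0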
def forward(self, input, objects, objects_length):
    object_features = input

    if self.object_supervision and self.concatenative_pair_representation:
        context_features = self.context_feature_extract(input)
        outputs = list()
        objects_index = 0
        for i in range(input.size(0)):
            box = objects[objects_index:objects_index + objects_length[i].item()]  # box is a list of object boundaries for the image
            objects_index += objects_length[i].item()

            with torch.no_grad():
                batch_ind = i + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)

                # generate a "full-image" bounding box
                image_h, image_w = input.size(2) * self.downsample_rate, input.size(3) * self.downsample_rate
                image_box = torch.cat([
                    torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    image_w + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    image_h + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)
                ], dim=-1)

                # intersection maps
                box_context_imap = functional.generate_intersection_map(box, image_box, self.pool_size)

            this_context_features = self.context_roi_pool(context_features, torch.cat([batch_ind, image_box], dim=-1))
            x, y = this_context_features.chunk(2, dim=1)
            this_object_features = self.object_feature_fuse(torch.cat([
                self.object_roi_pool(object_features, torch.cat([batch_ind, box], dim=-1)),
                x, y * box_context_imap
            ], dim=1))

            object_representations = self._norm(self.object_feature_fc(this_object_features.view(box.size(0), -1)))
            object_pair_representations = self.objects_to_pair_representations(object_representations)

            if DEBUG:
                outputs.append([None, this_object_features, None])
            else:
                outputs.append([None, object_representations, object_pair_representations])

    elif self.object_supervision and not self.concatenative_pair_representation:
        object_features = input
        context_features = self.context_feature_extract(input)
        relation_features = self.relation_feature_extract(input)
        outputs = list()
        objects_index = 0
        for i in range(input.size(0)):
            box = objects[objects_index:objects_index + objects_length[i].item()]
            objects_index += objects_length[i].item()

            with torch.no_grad():
                batch_ind = i + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)

                # generate a "full-image" bounding box
                image_h, image_w = input.size(2) * self.downsample_rate, input.size(3) * self.downsample_rate
                image_box = torch.cat([
                    torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    image_w + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device),
                    image_h + torch.zeros(box.size(0), 1, dtype=box.dtype, device=box.device)
                ], dim=-1)

                # meshgrid to obtain the subject and object bounding boxes
                sub_id, obj_id = jactorch.meshgrid(
                    torch.arange(box.size(0), dtype=torch.int64, device=box.device), dim=0)
                sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)
                sub_box, obj_box = jactorch.meshgrid(box, dim=0)
                sub_box = sub_box.contiguous().view(box.size(0) ** 2, 4)
                obj_box = obj_box.contiguous().view(box.size(0) ** 2, 4)

                # union box
                union_box = functional.generate_union_box(sub_box, obj_box)
                rel_batch_ind = i + torch.zeros(union_box.size(0), 1, dtype=box.dtype, device=box.device)

                # intersection maps
                box_context_imap = functional.generate_intersection_map(box, image_box, self.pool_size)
                sub_union_imap = functional.generate_intersection_map(sub_box, union_box, self.pool_size)
                obj_union_imap = functional.generate_intersection_map(obj_box, union_box, self.pool_size)

            this_context_features = self.context_roi_pool(context_features, torch.cat([batch_ind, image_box], dim=-1))
            x, y = this_context_features.chunk(2, dim=1)
            this_object_features = self.object_feature_fuse(torch.cat([
                self.object_roi_pool(object_features, torch.cat([batch_ind, box], dim=-1)),
                x, y * box_context_imap
            ], dim=1))

            this_relation_features = self.relation_roi_pool(relation_features, torch.cat([rel_batch_ind, union_box], dim=-1))
            x, y, z = this_relation_features.chunk(3, dim=1)
            this_relation_features = self.relation_feature_fuse(torch.cat([
                this_object_features[sub_id], this_object_features[obj_id],
                x, y * sub_union_imap, z * obj_union_imap
            ], dim=1))

            print(self._norm(self.relation_feature_fc(this_relation_features.view(
                box.size(0) * box.size(0), -1)).view(box.size(0), box.size(0), -1)).size())

            if DEBUG:
                outputs.append([None, this_object_features, this_relation_features])
            else:
                outputs.append([
                    None,
                    self._norm(self.object_feature_fc(this_object_features.view(box.size(0), -1))),
                    self._norm(self.relation_feature_fc(this_relation_features.view(
                        box.size(0) * box.size(0), -1)).view(box.size(0), box.size(0), -1))
                ])

    elif not self.object_supervision and self.concatenative_pair_representation:
        outputs = list()
        # object_features has shape batch_size x 256 x 16 x 24
        obj_coord_map = coord_map((object_features.size(2), object_features.size(3)), self.query.device)
        for i in range(input.size(0)):
            single_scene_object_features = torch.squeeze(object_features[i, :], dim=0)  # dim = 256 x 16 x 24
            scene_object_coords = torch.unsqueeze(
                torch.cat((single_scene_object_features, obj_coord_map), dim=0), dim=0)
            fused_object_coords = torch.squeeze(self.object_coord_fuse(scene_object_coords), dim=0)  # dim = 256 x Z x Y

            num_objects = objects_length[i].item()
            relevant_queries = self.query[0:num_objects, :]  # num_objects x feature_dim
            attention_map = self.temperature * torch.einsum(
                "ij,jkl -> ikl", relevant_queries, fused_object_coords)  # dim = num_objects x Z x Y
            attention_map = nn.Softmax(1)(attention_map.view(num_objects, -1)).view_as(attention_map)
            object_values = torch.einsum("ijk,ljk -> il", attention_map, fused_object_coords)  # dim = num_objects x 256

            object_representations = self._norm(self.object_features_layer(object_values))
            object_pair_representations = self.objects_to_pair_representations(object_representations)

            outputs.append([None, object_representations, object_pair_representations])

    return outputs
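# Illustrative sketch (not part of the original code): the query-attention pooling used
# in the unsupervised branch above, written out with plain torch. Names and sizes here
# are hypothetical stand-ins for self.query and the fused feature/coordinate map.
import torch
import torch.nn as nn

num_objects, feat_dim, H, W = 3, 32, 16, 24
queries = torch.randn(num_objects, feat_dim)            # one learned query per object slot
fmap = torch.randn(feat_dim, H, W)                      # fused feature + coordinate map
attn = torch.einsum("ij,jkl -> ikl", queries, fmap)     # (num_objects, H, W) attention logits
attn = nn.Softmax(dim=1)(attn.view(num_objects, -1)).view_as(attn)  # softmax over spatial locations
obj_vec = torch.einsum("ijk,ljk -> il", attn, fmap)     # (num_objects, feat_dim) pooled features
assert obj_vec.shape == (num_objects, feat_dim)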