def __init__(self, classes, num_rels, mode='sgdet', embed_dim=200, pooling_dim=4096,
             use_bias=True):
    super(EndCell, self).__init__()
    self.classes = classes
    self.num_rels = num_rels
    assert mode in MODES
    self.embed_dim = embed_dim
    self.pooling_dim = pooling_dim
    self.use_bias = use_bias
    self.mode = mode

    self.ort_embedding = torch.autograd.Variable(
        get_ort_embeds(self.num_classes, self.embed_dim).cuda())

    self.context = LC(classes=self.classes, mode=self.mode,
                      embed_dim=self.embed_dim, obj_dim=self.pooling_dim)

    self.union_boxes = UnionBoxesAndFeats(pooling_size=7, stride=16, dim=512)
    self.pooling_size = 7

    roi_fmap = [
        Flattener(),
        load_vgg(use_dropout=False, use_relu=False,
                 use_linear=pooling_dim == 4096, pretrained=False).classifier,
    ]
    if pooling_dim != 4096:
        roi_fmap.append(nn.Linear(4096, pooling_dim))
    self.roi_fmap = nn.Sequential(*roi_fmap)
    self.roi_fmap_obj = load_vgg(pretrained=False).classifier

    self.post_lstm = nn.Linear(self.pooling_dim + self.embed_dim + 5, self.pooling_dim * 2)
    # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
    # (Half of the contribution comes from the LSTM, half from the embedding.)
    # In practice the pre-LSTM activations tend to have stdev 0.1, so this is multiplied by 10.
    self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.pooling_dim))
    self.post_lstm.bias.data.zero_()

    self.post_emb = nn.Linear(self.pooling_dim + self.embed_dim + 5, self.pooling_dim * 2)

    self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
    self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)

    if self.use_bias:
        self.freq_bias = FrequencyBias()
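# Hedged sketch of what an orthogonal class-embedding helper such as
# get_ort_embeds(num_classes, embed_dim) could look like. The real helper lives
# elsewhere in the repo, so this construction (orthogonal_ init over an empty
# matrix) is an assumption made only for illustration.
import torch

def get_ort_embeds_sketch(num_classes, embed_dim):
    """Return a [num_classes, embed_dim] matrix with (near-)orthogonal rows."""
    embeds = torch.empty(num_classes, embed_dim)
    torch.nn.init.orthogonal_(embeds)  # fills the matrix with a (semi-)orthogonal basis
    return embeds

# Usage: one fixed, non-trainable embedding row per object class, e.g.
# ort_embedding = get_ort_embeds_sketch(151, 200)  # 150 VG classes + background (assumed sizes)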
def __init__(self, train_data, mode='sgdet', num_gpus=1, require_overlap_det=True,
             use_bias=False, test_bias=False, detector_model='baseline', RELS_PER_IMG=1024):
    """
    :param mode: (sgcls, predcls, or sgdet)
    :param num_gpus: how many GPUs to use
    :param require_overlap_det: whether two objects must intersect
    """
    super(RelModelBase, self).__init__()
    self.classes = train_data.ind_to_classes
    self.rel_classes = train_data.ind_to_predicates
    self.num_gpus = num_gpus
    assert mode in MODES
    self.mode = mode
    self.detector_model = detector_model
    self.RELS_PER_IMG = RELS_PER_IMG
    self.pooling_size = 7
    self.stride = 16
    self.obj_dim = 4096
    self.use_bias = use_bias
    self.test_bias = test_bias
    self.require_overlap = require_overlap_det and self.mode == 'sgdet'

    if self.detector_model == 'mrcnn':
        print('\nLoading COCO pretrained model maskrcnn_resnet50_fpn...\n')
        # See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
        self.detector = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True, box_detections_per_img=50, box_score_thresh=0.2)
        in_features = self.detector.roi_heads.box_predictor.cls_score.in_features
        # Replace the pre-trained head with a new one
        self.detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(self.classes))
        self.detector.roi_heads.mask_predictor = None

    self.union_boxes = UnionBoxesAndFeats(
        pooling_size=self.pooling_size, stride=self.stride,
        dim=256 if self.detector_model == 'mrcnn' else 512)

    if self.detector_model == 'mrcnn':
        layers = list(self.detector.roi_heads.children())[:2]
        self.roi_fmap_obj = copy.deepcopy(layers[1])
        self.roi_fmap = copy.deepcopy(layers[1])
        self.multiscale_roi_pool = copy.deepcopy(layers[0])
    else:
        raise NotImplementedError(self.detector_model)

    if self.use_bias:
        self.freq_bias = FrequencyBias(train_data)
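# Minimal, self-contained sketch of the head replacement performed above, following
# the torchvision detection tutorial referenced in the comment. The class count of
# 151 is an illustrative assumption, not a value taken from the repo's config.
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def build_mrcnn_detector_sketch(num_classes=151):
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True, box_detections_per_img=50, box_score_thresh=0.2)
    # Swap the COCO box head for one sized to the scene-graph object classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # The mask head is not needed for relationship detection.
    model.roi_heads.mask_predictor = None
    return model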
class RelModel(nn.Module):
    """
    RELATIONSHIPS
    """

    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True,
                 require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True,
                 use_tanh=True, limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim:
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.hook_for_grad = False
        self.gradients = []

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, 200).cuda())

        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        # This probably doesn't help it much
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        self.context = LinearizedContext(
            self.classes, self.rel_classes, mode=self.mode,
            embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge,
            dropout_rate=rec_dropout, order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image feats (you'll have to disable this if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        self.merge_obj_feats = nn.Sequential(
            nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim),
            nn.ReLU())
        # self.trans = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim//4),
        #                            LayerNorm(self.hidden_dim//4), nn.ReLU(),
        #                            nn.Linear(self.hidden_dim//4, self.hidden_dim))
        self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim)
        self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim)
        self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim,
                            hidden_size=self.hidden_dim, num_layers=1)
        self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim)
        # self.obj_mps2 = Message_Passing4OBJ(self.hidden_dim)
        self.get_boxes_encode = Boxes_Encode(64)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False,
                         use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # self.obj_classify_head = nn.Linear(self.pooling_dim, self.num_classes)

        # self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_s.weight = torch.nn.init.xavier_normal(self.post_emb_s.weight, gain=1.0)
        # self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_o.weight = torch.nn.init.xavier_normal(self.post_emb_o.weight, gain=1.0)
        # self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim//2)
        # self.merge_obj_high.weight = torch.nn.init.xavier_normal(self.merge_obj_high.weight, gain=1.0)
        # self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim//2)
        # self.merge_obj_low.weight = torch.nn.init.xavier_normal(self.merge_obj_low.weight, gain=1.0)
        # self.rel_compress = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        # self.freq_gate = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.freq_gate.weight = torch.nn.init.xavier_normal(self.freq_gate.weight, gain=1.0)

        self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_s.weight = torch.nn.init.xavier_normal(self.post_emb_s.weight, gain=1.0)
        self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_o.weight = torch.nn.init.xavier_normal(self.post_emb_o.weight, gain=1.0)
        self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim)
        self.merge_obj_high.weight = torch.nn.init.xavier_normal(self.merge_obj_high.weight, gain=1.0)
        self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim)
        self.merge_obj_low.weight = torch.nn.init.xavier_normal(self.merge_obj_low.weight, gain=1.0)
        self.rel_compress = nn.Linear(self.pooling_dim + 64, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        self.freq_gate = nn.Linear(self.pooling_dim + 64, self.num_rels, bias=True)
        self.freq_gate.weight = torch.nn.init.xavier_normal(self.freq_gate.weight, gain=1.0)
        # self.ranking_module = nn.Sequential(nn.Linear(self.pooling_dim + 64, self.hidden_dim),
        #                                     nn.ReLU(), nn.Linear(self.hidden_dim, 1))

        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    # def fixed_obj_modules(self):
    #     for p in self.detector.parameters(): p.requires_grad = False
    #     for p in self.obj_embed.parameters(): p.requires_grad = False
    #     for p in self.pos_embed.parameters(): p.requires_grad = False
    #     for p in self.context.parameters(): p.requires_grad = False
    #     for p in self.union_boxes.parameters(): p.requires_grad = False
    #     for p in self.merge_obj_feats.parameters(): p.requires_grad = False
    #     for p in self.get_phr_feats.parameters(): p.requires_grad = False
    #     for p in self.embeddings4lstm.parameters(): p.requires_grad = False
    #     for p in self.lstm.parameters(): p.requires_grad = False
    #     for p in self.obj_mps1.parameters(): p.requires_grad = False
    #     for p in self.roi_fmap_obj.parameters(): p.requires_grad = False
    #     for p in self.roi_fmap.parameters(): p.requires_grad = False

    def save_grad(self, grad):
        self.gradients.append(grad)

    def visual_rep(self, features, rois, pair_inds):
        """
        Classify the features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: inds to use when predicting
        :return: score_pred, a [num_rois, num_classes] array
                 box_pred, a [num_rois, num_classes, 4] array
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def visual_obj(self, features, rois, pair_inds):
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return uboxes

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        # Get the relationship candidates
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0)
                # If there are fewer than 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def union_pairs(self, im_inds):
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
        rel_inds = rel_cands.nonzero()
        rel_inds = torch.cat((im_inds[rel_inds[:, 0]][:, None].data, rel_inds), -1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None,
                proposals=None, train_anchor_inds=None, return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
            be used to compute the training loss. Each is (img_ind, fpn_idx)
        :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
                 If test: prob dists, boxes, img inds, maxscores, classes
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels,
                               proposals, train_anchor_inds, return_fmap=True)
        # rel_feat = self.relationship_feat.feature_map(x)

        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        spt_feats = self.get_boxes_encode(boxes, rel_inds)
        pair_inds = self.union_pairs(im_inds)

        if self.hook_for_grad:
            rel_inds = gt_rels[:, :-1].data

        if self.hook_for_grad:
            fmap = result.fmap
            fmap.register_hook(self.save_grad)
        else:
            fmap = result.fmap.detach()

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(fmap, rois)
        # result.obj_dists_head = self.obj_classify_head(obj_fmap_rel)

        obj_embed = F.softmax(result.rm_obj_dists, dim=1) @ self.obj_embed.weight
        obj_embed_lstm = F.softmax(result.rm_obj_dists, dim=1) @ self.embeddings4lstm.weight
        pos_embed = self.pos_embed(Variable(center_size(boxes.data)))
        obj_pre_rep = torch.cat((result.obj_fmap, obj_embed, pos_embed), 1)
        obj_feats = self.merge_obj_feats(obj_pre_rep)
        # obj_feats = self.trans(obj_feats)
        obj_feats_lstm = torch.cat((obj_feats, obj_embed_lstm), -1).contiguous().view(
            1, obj_feats.size(0), -1)
        # obj_feats = F.relu(obj_feats)

        phr_ori = self.visual_rep(fmap, rois, pair_inds[:, 1:])
        vr_indices = torch.from_numpy(
            intersect_2d(rel_inds[:, 1:].cpu().numpy(),
                         pair_inds[:, 1:].cpu().numpy()).astype(np.uint8)).cuda().max(-1)[1]
        vr = phr_ori[vr_indices]

        phr_feats_high = self.get_phr_feats(phr_ori)

        obj_feats_lstm_output, (obj_hidden_states, obj_cell_states) = self.lstm(obj_feats_lstm)

        rm_obj_dists1 = result.rm_obj_dists + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(
            obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)),
            phr_feats_high, im_inds, pair_inds)

        obj_embed_lstm1 = F.softmax(rm_obj_dists1, dim=1) @ self.embeddings4lstm.weight
        obj_feats_lstm1 = torch.cat((obj_feats_output, obj_embed_lstm1), -1).contiguous().view(
            1, obj_feats_output.size(0), -1)
        obj_feats_lstm_output, _ = self.lstm(obj_feats_lstm1,
                                             (obj_hidden_states, obj_cell_states))
        rm_obj_dists2 = rm_obj_dists1 + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(
            obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)),
            phr_feats_high, im_inds, pair_inds)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds = self.context(
            rm_obj_dists2,
            obj_feats_output,
            result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all)

        obj_dtype = result.obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              result.obj_preds).type(obj_dtype)
        tranfered_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE, boxes[:, 2] / IM_SCALE,
             boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) / (IM_SCALE ** 2)),
            -1).type(obj_dtype)
        obj_features = torch.cat((result.obj_fmap, obj_preds_embeds, tranfered_boxes), -1)
        obj_features_merge = self.merge_obj_low(obj_features) + self.merge_obj_high(obj_feats_output)

        # Split into subject and object representations
        result.subj_rep = self.post_emb_s(obj_features_merge)[rel_inds[:, 1]]
        result.obj_rep = self.post_emb_o(obj_features_merge)[rel_inds[:, 2]]
        prod_rep = result.subj_rep * result.obj_rep

        # obj_pools = self.visual_obj(result.fmap.detach(), rois, rel_inds[:, 1:])
        # rel_pools = self.relationship_feat.union_rel_pooling(rel_feat, rois, rel_inds[:, 1:])
        # context_pools = torch.cat([obj_pools, rel_pools], 1)
        # merge_pool = self.merge_feat(context_pools)
        # vr = self.roi_fmap(merge_pool)

        # vr = self.rel_refine(vr)

        prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        prod_rep = torch.cat((prod_rep, spt_feats), -1)

        freq_gate = self.freq_gate(prod_rep)
        freq_gate = F.sigmoid(freq_gate)
        result.rel_dists = self.rel_compress(prod_rep)
        # result.rank_factor = self.ranking_module(prod_rep).view(-1)

        if self.use_bias:
            result.rel_dists = result.rel_dists + freq_gate * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # rel_rep = smooth_one_hot(rel_rep)
        # rank_factor = F.sigmoid(result.rank_factor)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training """
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
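# Toy-sized sketch of the gated frequency-bias step used in the forward pass above.
# The real FrequencyBias module lives elsewhere in the repo; here it is mocked with a
# random [num_classes, num_classes, num_rels] lookup table, so all shapes and values
# below are illustrative assumptions.
import torch

num_classes, num_rels, num_pairs = 5, 4, 3
baseline_table = torch.randn(num_classes, num_classes, num_rels)  # stand-in for FrequencyBias

rel_dists = torch.randn(num_pairs, num_rels)                   # rel_compress(prod_rep)
freq_gate = torch.sigmoid(torch.randn(num_pairs, num_rels))    # sigmoid(freq_gate(prod_rep))
subj_preds = torch.tensor([1, 2, 0])
obj_preds = torch.tensor([3, 4, 2])

# index_with_labels(...) roughly corresponds to looking up the (subj, obj) slice:
bias = baseline_table[subj_preds, obj_preds]     # [num_pairs, num_rels]
rel_dists = rel_dists + freq_gate * bias         # the gate decides how much prior to trust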
class RelModel(nn.Module):
    """
    RELATIONSHIPS
    """

    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True,
                 require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True,
                 use_tanh=True, limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim:
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes, self.rel_classes, mode=self.mode,
            embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge,
            dropout_rate=rec_dropout, order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image feats (you'll have to disable this if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False,
                         use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half of the contribution comes from the LSTM, half from the embedding.)
        # In practice the pre-LSTM activations tend to have stdev 0.1, so this is multiplied by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # Not too large, because within the same image the rel class is mostly 0;
        # if it is too large, most negative rels are repeated.
        self.neg_num = 1

        """
        self.embdim = 100
        self.obj1_fc = nn.Sequential(
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, self.num_classes * self.embdim, bias=True),
            nn.BatchNorm1d(self.num_classes * self.embdim),
            nn.ReLU(inplace=True),
        )
        self.obj2_fc = nn.Sequential(
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, self.num_classes * self.embdim, bias=True),
            nn.BatchNorm1d(self.num_classes * self.embdim),
            nn.ReLU(inplace=True),
        )
        self.rel_seq = nn.Sequential(
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, self.num_rels * self.embdim, bias=True),
            nn.BatchNorm1d(self.num_rels * self.embdim),
            nn.ReLU(inplace=True),
        )
        # self.new_roi_fmap_obj = load_vgg(pretrained=False).classifier
        """

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Classify the features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: inds to use when predicting
        :return: score_pred, a [num_rois, num_classes] array
                 box_pred, a [num_rois, num_classes, 4] array
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        # Get the relationship candidates
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0)
                # If there are fewer than 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))  # vgg.classifier

    def new_obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.new_roi_fmap_obj(feature_pool.view(rois.size(0), -1))  # vgg.classifier

    def get_neg_examples(self, rel_labels):
        """
        Given relationship combinations (positive examples), return the negative examples.
        :param rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
        :return: neg_rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
        """
        neg_rel_labels = []
        num_im = rel_labels.data[:, 0].max() + 1
        im_inds = rel_labels.data.cpu().numpy()[:, 0]
        rel_type = rel_labels.data.cpu().numpy()[:, 3]
        box_pairs = rel_labels.data.cpu().numpy()[:, :3]

        for im_ind in range(num_im):
            pred_ind = np.where(im_inds == im_ind)[0]
            rel_type_i = rel_type[pred_ind]
            rel_labels_i = box_pairs[pred_ind][:, None, :]
            row_num = rel_labels_i.shape[0]
            rel_labels_i = torch.LongTensor(rel_labels_i).expand_as(
                torch.Tensor(row_num, self.neg_num, 3))
            neg_pairs_i = rel_labels_i.contiguous().view(-1, 3).cpu().numpy()

            neg_rel_type_i = np.zeros(self.neg_num)
            for k in range(rel_type_i.shape[0]):
                # Delete the same rel class
                neg_rel_type_k = np.delete(rel_type_i, np.where(rel_type_i == rel_type_i[k])[0])
                # assert neg_rel_type_k.shape[0] != 0
                if neg_rel_type_k.shape[0] != 0:
                    neg_rel_type_k = np.random.choice(neg_rel_type_k, size=self.neg_num, replace=True)
                    neg_rel_type_i = np.concatenate((neg_rel_type_i, neg_rel_type_k), axis=0)
                else:
                    orig_cls = np.arange(self.num_rels)
                    cls_pool = np.delete(orig_cls, np.where(orig_cls == rel_type_i[k])[0])
                    neg_rel_type_k = np.random.choice(cls_pool, size=self.neg_num, replace=False)
                    neg_rel_type_i = np.concatenate((neg_rel_type_i, neg_rel_type_k), axis=0)

            # Delete the first few rows
            neg_rel_type_i = np.delete(neg_rel_type_i, np.arange(self.neg_num))
            assert neg_pairs_i.shape[0] == neg_rel_type_i.shape[0]
            neg_rel_labels.append(np.column_stack((neg_pairs_i, neg_rel_type_i)))

        neg_rel_labels = torch.LongTensor(np.concatenate(np.array(neg_rel_labels), 0))
        neg_rel_labels = neg_rel_labels.cuda(rel_labels.get_device(), non_blocking=True)
        return neg_rel_labels

    def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None,
                proposals=None, train_anchor_inds=None, return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
            be used to compute the training loss. Each is (img_ind, fpn_idx)
        :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
                 If test: prob dists, boxes, img inds, maxscores, classes
        """
        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels,
                               proposals, train_anchor_inds, return_fmap=True)
        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; where the narrow error comes from, should .detach()
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'  # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:, :3]
            # torch.cat((result.rel_labels[:,0].contiguous().view(236,1),
            #            result.rm_obj_labels[result.rel_labels[:,1]].view(236,1),
            #            result.rm_obj_labels[result.rel_labels[:,2]].view(236,1),
            #            result.rel_labels[:,3].contiguous().view(236,1)), -1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  # [275, 3], [im_inds, box1_inds, box2_inds]

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # result.rm_obj_fmap: [384, 4096]
        # result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)  # detach: prevent backward flow
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach())  # detach: prevent backward flow

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,  # has been detached above
            # rm_obj_dists: [#boxes, 151]; prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach: returns a new Variable, detached from the current graph
            im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data,
            result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)

        # Post processing
        if edge_ctx is None:  # nl_edge <= 0
            edge_rep = self.post_emb(result.rm_obj_preds)
        else:  # nl_edge > 0
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  # [384, 2, 4096]
        subj_rep = edge_rep[:, 0]  # [384, 4096]
        obj_rep = edge_rep[:, 1]  # [384, 4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275, 4096], [275, 3]

        if self.use_vision:  # True when sgdet
            # union rois: fmap.detach -- RoIAlignFunction -- roifmap -- vr [275, 4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr  # [275, 4096]

            if self.training:
                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                prod_rep_neg = subj_rep[rel_inds_neg[:, 1]].detach() * \
                    obj_rep[rel_inds_neg[:, 2]].detach() * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275, 51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.rm_obj_preds[rel_inds[:, 1]],
                    result.rm_obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            judge = result.rel_labels.data[:, 3] != 0
            if judge.sum() != 0:
                # A gt_rel exists in rel_inds
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1, 1).long().cuda()[
                    result.rel_labels.data[:, 3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]

                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]
                # Only 1/4 of the 384 obj_dists will be updated, because only 1/4 of the objs' labels are not 0.

                # Positive overall score
                obj_scores0 = result.obj_scores[com_rel_inds[:, 1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:, 2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)  # result.rel_dists has grad
                _, pred_classes_argmax = rel_rep.data[:, :].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1, 1))).squeeze()  # SqueezeBackward, GatherBackward
                score_list = torch.cat((com_rel_inds[:, 0].float().contiguous().view(-1, 1),
                                        obj_scores0.data.view(-1, 1),
                                        obj_scores1.data.view(-1, 1),
                                        max_rel_score.data.view(-1, 1)), 1)
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                # pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0
                # Treat most rel_labels as neg because their rel cls is 0 ("unknown")

                # Negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:, 1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:, 2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)  # rel_dists_neg has grad
                _, pred_classes_argmax_neg = rel_rep_neg.data[:, :].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1, 1))).squeeze()  # SqueezeBackward, GatherBackward
                score_list_neg = torch.cat((rel_inds_neg[:, 0].float().contiguous().view(-1, 1),
                                            obj_scores0_neg.data.view(-1, 1),
                                            obj_scores1_neg.data.view(-1, 1),
                                            max_rel_score_neg.data.view(-1, 1)), 1)
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # Use all rel_inds; this is already irrelevant to im_inds, which is only used to
                # extract regions from the image and produce rel_inds.
                # 384 boxes --(rel_inds)(rel_inds_neg)--> prob_score, prob_score_neg
                all_rel_inds = torch.cat((result.rel_labels.data[select_rel_inds], rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0), 1).cuda(),
                                  torch.zeros(prob_score_neg.size(0), 1).cuda()), 0)
                score_list_all = torch.cat((score_list, score_list_neg), 0)
                all_prob = torch.cat((prob_score, prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]
                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check the distribution of pos and neg
                sorted_score_list_all = score_list_all[sort_prob_inds]
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                # Positive triplets and score list
                pos_sorted_inds = sorted_rel_inds.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 1).view(-1, 4)
                pos_trips = torch.cat((pos_sorted_inds[:, 0].contiguous().view(-1, 1),
                                       result.rm_obj_labels.data.view(-1, 1)[pos_sorted_inds[:, 1]],
                                       result.rm_obj_labels.data.view(-1, 1)[pos_sorted_inds[:, 2]],
                                       pos_sorted_inds[:, 3].contiguous().view(-1, 1)), 1)
                pos_score_list = sorted_score_list_all.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 1).view(-1, 4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable

                # Negative triplets and score list
                neg_sorted_inds = sorted_rel_inds.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 0).view(-1, 4)
                neg_trips = torch.cat((neg_sorted_inds[:, 0].contiguous().view(-1, 1),
                                       result.rm_obj_labels.data.view(-1, 1)[neg_sorted_inds[:, 1]],
                                       result.rm_obj_labels.data.view(-1, 1)[neg_sorted_inds[:, 2]],
                                       neg_sorted_inds[:, 3].contiguous().view(-1, 1)), 1)
                neg_score_list = sorted_score_list_all.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 0).view(-1, 4)
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:, None].expand_as(
                    torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                # Use the minimum pos to correspond to the maximum negative
                int_part_inds = (int(pos_exp.size(0) - 1) - int_inds).long().cuda()
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat(
                        (torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(),
                         int_part_inds), 0)
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                # some variables .register_hook(extract_grad)

                return result
            else:
                # No gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]

                # Positive overall score
                obj_scores0 = result.obj_scores[rel_inds[:, 1]]
                obj_scores1 = result.obj_scores[rel_inds[:, 2]]
                rel_rep = F.softmax(result.rel_dists, dim=1)  # [275, 51]
                _, pred_classes_argmax = rel_rep.data[:, :].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1, 1))).squeeze()  # SqueezeBackward, GatherBackward
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                # pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0
                # Treat most rel_labels as neg because their rel cls is 0 ("unknown")

                # Negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:, 1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:, 2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = rel_rep_neg.data[:, :].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1, 1))).squeeze()  # SqueezeBackward, GatherBackward
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # Use all rel_inds; this is already irrelevant to im_inds, which is only used to
                # extract regions from the image and produce rel_inds.
                # 384 boxes --(rel_inds)(rel_inds_neg)--> prob_score, prob_score_neg
                all_rel_inds = torch.cat((result.rel_labels.data, rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0), 1).cuda(),
                                  torch.zeros(prob_score_neg.size(0), 1).cuda()), 0)
                all_prob = torch.cat((prob_score, prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]
                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check the distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                pos_sorted_inds = sorted_rel_inds.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 1).view(-1, 4)
                neg_sorted_inds = sorted_rel_inds.masked_select(
                    sorted_flag.view(-1, 1).expand(-1, 4).cuda() == 0).view(-1, 4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.data.size(0))[:, None].expand_as(
                    torch.Tensor(pos_exp.data.size(0), int_part)).contiguous().view(-1)
                # Use the minimum pos to correspond to the maximum negative
                int_part_inds = (int(pos_exp.data.size(0) - 1) - int_inds).long().cuda()
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat(
                        (torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(),
                         int_part_inds), 0)
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                return result

        ###################### Testing ###########################
        # Extract corresponding scores according to the boxes' preds
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]  # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)  # [275, 51]
        # Sort by the product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores, result.rm_obj_preds, rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training """
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
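# Hedged sketch of how the (result.pos, result.neg, result.anchor) tensors produced
# above could feed a margin-based ranking objective in the training script. The
# repo's actual loss is defined elsewhere; the use of F.margin_ranking_loss and the
# margin value below are assumptions for illustration only.
import torch
import torch.nn.functional as F

def contrastive_rel_loss_sketch(pos, neg, margin=0.2):
    """pos/neg: matched overall scores of positive/negative relation triplets."""
    target = torch.ones_like(pos)  # encode "pos should rank higher than neg"
    return F.margin_ranking_loss(pos, neg, target, margin=margin)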
def __init__(self, classes, rel_classes, embed_dim, obj_dim, inputs_dim, hidden_dim,
             pooling_dim, recurrent_dropout_probability=0.2, use_highway=True,
             use_input_projection_bias=True, use_vision=True, use_bias=True, use_tanh=True,
             limit_vision=True, sl_pretrain=False, num_iter=-1):
    """
    Initializes the RNN
    :param embed_dim: Dimension of the embeddings
    :param encoder_hidden_dim: Hidden dim of the encoder, for attention purposes
    :param hidden_dim: Hidden dim of the decoder
    :param vocab_size: Number of words in the vocab
    :param bos_token: To use during decoding (non teacher forcing mode)
    :param bos: beginning of sentence token
    :param unk: unknown token (not used)
    """
    super(DecoderRNN, self).__init__()
    self.rel_embedding_dim = 100
    self.classes = classes
    self.rel_classes = rel_classes
    embed_vecs = obj_edge_vectors(['start'] + self.classes, wv_dim=100)
    self.obj_embed = nn.Embedding(len(self.classes), embed_dim)
    self.obj_embed.weight.data = embed_vecs

    embed_rels = obj_edge_vectors(self.rel_classes, wv_dim=self.rel_embedding_dim)
    self.rel_embed = nn.Embedding(len(self.rel_classes), self.rel_embedding_dim)
    self.rel_embed.weight.data = embed_rels

    self.embed_dim = embed_dim
    self.obj_dim = obj_dim
    self.hidden_size = hidden_dim
    self.inputs_dim = inputs_dim
    self.pooling_dim = pooling_dim
    self.nms_thresh = 0.3

    self.use_vision = use_vision
    self.use_bias = use_bias
    self.use_tanh = use_tanh
    self.limit_vision = limit_vision
    self.sl_pretrain = sl_pretrain
    self.num_iter = num_iter

    self.recurrent_dropout_probability = recurrent_dropout_probability
    self.use_highway = use_highway

    # We do the projections for all the gates all at once, so if we are
    # using highway layers, we need some extra projections, which is
    # why the sizes of the Linear layers change here depending on this flag.
    if use_highway:
        self.input_linearity = torch.nn.Linear(self.input_size, 6 * self.hidden_size,
                                               bias=use_input_projection_bias)
        self.state_linearity = torch.nn.Linear(self.hidden_size, 5 * self.hidden_size,
                                               bias=True)
    else:
        self.input_linearity = torch.nn.Linear(self.input_size, 4 * self.hidden_size,
                                               bias=use_input_projection_bias)
        self.state_linearity = torch.nn.Linear(self.hidden_size, 4 * self.hidden_size,
                                               bias=True)

    # self.obj_in_lin = torch.nn.Linear(self.rel_embedding_dim, self.rel_embedding_dim, bias=True)

    self.out = nn.Linear(self.hidden_size, len(self.classes))
    self.reset_parameters()

    # For relation prediction
    embed_vecs2 = obj_edge_vectors(self.classes, wv_dim=embed_dim)
    self.obj_embed2 = nn.Embedding(self.num_classes, embed_dim)
    self.obj_embed2.weight.data = embed_vecs2.clone()

    # self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)
    self.post_lstm = nn.Linear(self.obj_dim + 2 * self.embed_dim + 128, self.pooling_dim * 2)
    # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
    # (Half of the contribution comes from the LSTM, half from the embedding.)
    # In practice the pre-LSTM activations tend to have stdev 0.1, so this is multiplied by 10.
    self.post_lstm.weight.data.normal_(
        0, 10.0 * math.sqrt(1.0 / self.hidden_size))  # this may need more consideration
    self.post_lstm.bias.data.zero_()

    self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
    self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
    if self.use_bias:
        self.freq_bias = FrequencyBias()

    # Simple relation model
    from dataloaders.visual_genome import VG
    from lib.get_dataset_counts import get_counts, box_filter

    fg_matrix, bg_matrix = get_counts(
        train_data=VG.splits(num_val_im=5000, filter_non_overlap=True,
                             filter_duplicate_rels=True, use_proposals=False)[0],
        must_overlap=True)
    prob_matrix = fg_matrix.astype(np.float32)
    prob_matrix[:, :, 0] = bg_matrix

    # TRYING SOMETHING NEW.
    prob_matrix[:, :, 0] += 1
    prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
    # prob_matrix /= float(fg_matrix.max())

    prob_matrix[:, :, 0] = 0  # Zero out BG
    self.prob_matrix = prob_matrix
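# Hedged sketch of how the prob_matrix built above could serve as a simple frequency
# baseline: for each (subject class, object class) pair it stores a normalized
# distribution over relation classes with the background column zeroed out. The lookup
# below is only illustrative; how self.prob_matrix is actually consumed is defined
# elsewhere in the repo.
import numpy as np

def rel_prior_sketch(prob_matrix, subj_classes, obj_classes):
    """prob_matrix: [num_obj_cls, num_obj_cls, num_rels]; returns [N, num_rels] priors."""
    return prob_matrix[np.asarray(subj_classes), np.asarray(obj_classes)]

# e.g. prior = rel_prior_sketch(self.prob_matrix, [57], [20])  # hypothetical class ids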
class RelModel(RelModelBase): """ Depth-Fusion relation detection model """ # -- Different components' FC layer size FC_SIZE_VISUAL = 512 FC_SIZE_CLASS = 64 FC_SIZE_LOC = 20 FC_SIZE_DEPTH = 4096 LOC_INPUT_SIZE = 8 def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=False, require_overlap_det=True, embed_dim=200, hidden_dim=4096, use_resnet=False, thresh=0.01, use_proposals=False, use_bias=True, limit_vision=True, depth_model=None, pretrained_depth=False, active_features=None, frozen_features=None, use_embed=False, **kwargs): """ :param classes: object classes :param rel_classes: relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: enable the contribution of union of bounding boxes :param require_overlap_det: whether two objects must intersect :param embed_dim: word2vec embeddings dimension :param hidden_dim: dimension of the fusion hidden layer :param use_resnet: use resnet as faster-rcnn's backbone :param thresh: faster-rcnn related threshold (Threshold for calling it a good box) :param use_proposals: whether to use region proposal candidates :param use_bias: enable frequency bias :param limit_vision: use truncated version of UoBB features :param depth_model: provided architecture for depth feature extraction :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features) :param frozen_features: what set of features should be frozen (e.g. 'd' : depth) :param use_embed: use word2vec embeddings """ RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus, require_overlap_det, active_features, frozen_features) self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.use_vision = use_vision self.use_bias = use_bias self.limit_vision = limit_vision # -- Store depth related parameters assert depth_model in DEPTH_MODELS self.depth_model = depth_model self.pretrained_depth = pretrained_depth self.depth_pooling_dim = DEPTH_DIMS[self.depth_model] self.use_embed = use_embed self.detector = nn.Module() features_size = 0 # -- Check whether ResNet is selected as faster-rcnn's backbone if use_resnet: raise ValueError( "The current model does not support ResNet as the Faster-RCNN's backbone." ) """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** This is the part where the different components of the proposed relation detection architecture are defined. In the case of RGB images, we have class probability distribution features, visual features, and the location ones. If we are considering depth images as well, we augment depth features too. 
""" # -- Visual features if self.has_visual: # -- Define faster R-CNN network and it's related feature extractors self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # -- Define union features if self.use_vision: # -- UoBB pooling module self.union_boxes = UnionBoxesAndFeats( pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # -- UoBB feature extractor roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=self.hidden_dim == 4096, pretrained=False).classifier, ] if self.hidden_dim != 4096: roi_fmap.append(nn.Linear(4096, self.hidden_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) # -- Define visual features hidden layer self.visual_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)), nn.ReLU(inplace=True), nn.Dropout(0.8) ]) self.visual_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_VISUAL # -- Location features if self.has_loc: # -- Define location features hidden layer self.location_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.location_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_LOC # -- Class features if self.has_class: if self.use_embed: # -- Define class embeddings embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim) self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed.weight.data = embed_vecs.clone() classme_input_dim = self.embed_dim if self.use_embed else self.num_classes # -- Define Class features hidden layer self.classme_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.classme_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_CLASS # -- Depth features if self.has_depth: # -- Initialize depth backbone self.depth_backbone = DepthCNN(depth_model=self.depth_model, pretrained=self.pretrained_depth) # -- Create a relation head which is used to carry on the feature extraction # from RoIs of depth features self.depth_rel_head = self.depth_backbone.get_classifier() # -- Define depth features hidden layer self.depth_rel_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)), nn.ReLU(inplace=True), nn.Dropout(0.6), ]) self.depth_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_DEPTH # -- Initialize frequency bias if needed if self.use_bias: self.freq_bias = FrequencyBias() # -- *** Fusion layer *** -- # -- A hidden layer for concatenated features (fusion features) self.fusion_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(features_size, self.hidden_dim)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) # -- Final FC layer which predicts the relations self.rel_out = xavier_init( nn.Linear(self.hidden_dim, self.num_rels, bias=True)) # -- Freeze the user specified features if self.frz_visual: self.freeze_module(self.detector) self.freeze_module(self.roi_fmap_obj) self.freeze_module(self.visual_hlayer) if self.use_vision: self.freeze_module(self.roi_fmap) self.freeze_module(self.union_boxes.conv) if self.frz_class: self.freeze_module(self.classme_hlayer) if self.frz_loc: self.freeze_module(self.location_hlayer) if self.frz_depth: self.freeze_module(self.depth_backbone) self.freeze_module(self.depth_rel_head) 
self.freeze_module(self.depth_rel_hlayer) def get_roi_features(self, features, rois): """ Gets ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :return: [num_rois, #dim] array """ feature_pool = RoIAlign((self.pooling_size, self.pooling_size), spatial_scale=1 / 16, sampling_ratio=-1)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def get_union_features(self, features, rois, pair_inds): """ Gets UoBB features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds: inds to use when predicting :return: UoBB features """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_roi_features_depth(self, features, rois): """ Gets ROI features (depth) :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :return: [num_rois, #dim] array """ feature_pool = RoIAlign((self.pooling_size, self.pooling_size), spatial_scale=1 / 16, sampling_ratio=-1)(features, rois) # -- Flatten the layer if the model is not RESNET/SQZNET if self.depth_model not in ('resnet18', 'resnet50', 'sqznet'): feature_pool = feature_pool.view(rois.size(0), -1) return self.depth_rel_head(feature_pool) @staticmethod def get_loc_features(boxes, subj_inds, obj_inds): """ Calculate the scale-invariant location feature :param boxes: ground-truth/detected boxes :param subj_inds: subject indices :param obj_inds: object indices :return: location_feature """ boxes_centered = center_size(boxes.data) # -- Determine box's center and size (subj's box) center_subj = boxes_centered[subj_inds][:, 0:2] size_subj = boxes_centered[subj_inds][:, 2:4] # -- Determine box's center and size (obj's box) center_obj = boxes_centered[obj_inds][:, 0:2] size_obj = boxes_centered[obj_inds][:, 2:4] # -- Calculate the scale-invariant location features of the subject t_coord_subj = (center_subj - center_obj) / size_obj t_size_subj = torch.log(size_subj / size_obj) # -- Calculate the scale-invariant location features of the object t_coord_obj = (center_obj - center_subj) / size_subj t_size_obj = torch.log(size_obj / size_subj) # -- Put everything together location_feature = Variable( torch.cat((t_coord_subj, t_size_subj, t_coord_obj, t_size_obj), 1)) return location_feature def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False, depth_imgs=None): """ Forward pass for relation detection :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE] :param im_sizes: a numpy array of (h, w, scale) for each image. :param image_offset: offset onto what image we're on for MGPU training (if single GPU this is 0) :param gt_boxes: [num_gt, 4] GT boxes over the batch. :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) :param gt_rels: [] gt relations :param proposals: region proposals retrieved from file :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will be used to compute the training loss. 
Each (img_ind, fpn_idx) :param return_fmap: if the object detector must return the extracted feature maps :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE] :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels if test: prob dists, boxes, img inds, maxscores, classes """ if self.has_visual: # -- Feed forward the rgb images to Faster-RCNN result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) else: # -- Get prior `result` object (instead of calling faster-rcnn's detector) result = self.get_prior_results(image_offset, gt_boxes, gt_classes, gt_rels) # -- Get RoI and relations rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes, gt_classes, gt_rels) boxes = result.rm_box_priors # -- Determine subject and object indices subj_inds = rel_inds[:, 1] obj_inds = rel_inds[:, 2] # -- Prepare object predictions vector (PredCLS) # replace with ground truth labels result.obj_preds = result.rm_obj_labels # replace with one-hot distribution of ground truth labels result.rm_obj_dists = F.one_hot(result.rm_obj_labels.data, self.num_classes).float() obj_cls = result.rm_obj_dists result.rm_obj_dists = result.rm_obj_dists * 1000 + ( 1 - result.rm_obj_dists) * (-1000) rel_features = [] # -- Extract RGB features if self.has_visual: # Feed the extracted features from first conv layers to the last 'classifier' layers (VGG) # Here, only the last 3 layers of VGG are being trained. Everything else (in self.detector) # is frozen. result.obj_fmap = self.get_roi_features(result.fmap.detach(), rois) # -- Create a pairwise relation vector out of visual features rel_visual = torch.cat( (result.obj_fmap[subj_inds], result.obj_fmap[obj_inds]), 1) rel_visual_fc = self.visual_hlayer(rel_visual) rel_visual_scale = self.visual_scale(rel_visual_fc) rel_features.append(rel_visual_scale) # -- Extract Location features if self.has_loc: # -- Create a pairwise relation vector out of location features rel_location = self.get_loc_features(boxes, subj_inds, obj_inds) rel_location_fc = self.location_hlayer(rel_location) rel_location_scale = self.location_scale(rel_location_fc) rel_features.append(rel_location_scale) # -- Extract Class features if self.has_class: if self.use_embed: obj_cls = obj_cls @ self.obj_embed.weight # -- Create a pairwise relation vector out of class features rel_classme = torch.cat((obj_cls[subj_inds], obj_cls[obj_inds]), 1) rel_classme_fc = self.classme_hlayer(rel_classme) rel_classme_scale = self.classme_scale(rel_classme_fc) rel_features.append(rel_classme_scale) # -- Extract Depth features if self.has_depth: # -- Extract features from depth backbone depth_features = self.depth_backbone(depth_imgs) depth_rois_features = self.get_roi_features_depth( depth_features, rois) # -- Create a pairwise relation vector out of location features rel_depth = torch.cat((depth_rois_features[subj_inds], depth_rois_features[obj_inds]), 1) rel_depth_fc = self.depth_rel_hlayer(rel_depth) rel_depth_scale = self.depth_scale(rel_depth_fc) rel_features.append(rel_depth_scale) # -- Create concatenated feature vector rel_fusion = torch.cat(rel_features, 1) # -- Extract relation embeddings (penultimate layer) rel_embeddings = self.fusion_hlayer(rel_fusion) # -- Mix relation embeddings with UoBB features if self.has_visual and self.use_vision: uobb_features = self.get_union_features(result.fmap.detach(), rois, rel_inds[:, 1:]) if self.limit_vision: # exact value TBD uobb_limit = 
int(self.hidden_dim / 2) rel_embeddings = torch.cat((rel_embeddings[:, :uobb_limit] * uobb_features[:, :uobb_limit], rel_embeddings[:, uobb_limit:]), 1) else: rel_embeddings = rel_embeddings * uobb_features # -- Predict relation distances result.rel_dists = self.rel_out(rel_embeddings) # -- Frequency bias if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels( torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training: return result # --- *** END OF ARCHITECTURE *** ---# twod_inds = arange( result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view( result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep)
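# --- Illustrative sketch (not from the original code) ---
# A minimal, self-contained sketch of the scale-invariant location encoding computed by
# get_loc_features above, applied to two toy boxes. The helper below is a simplified
# (cx, cy, w, h) conversion; the library's center_size may differ in off-by-one details.
import torch

def toy_loc_feature(box_subj, box_obj):
    def to_center_size(b):                      # (x1, y1, x2, y2) -> (cx, cy, w, h)
        wh = b[2:] - b[:2]
        return torch.cat([b[:2] + 0.5 * wh, wh])
    s, o = to_center_size(box_subj), to_center_size(box_obj)
    t_coord_subj = (s[:2] - o[:2]) / o[2:]      # subject offset, normalized by object size
    t_size_subj = torch.log(s[2:] / o[2:])      # log size ratio (scale-invariant)
    t_coord_obj = (o[:2] - s[:2]) / s[2:]
    t_size_obj = torch.log(o[2:] / s[2:])
    return torch.cat([t_coord_subj, t_size_subj, t_coord_obj, t_size_obj])  # 8-d feature

print(toy_loc_feature(torch.tensor([10., 10., 50., 30.]),
                      torch.tensor([30., 20., 70., 60.])))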
class RelModel(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision=limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' # print('REL MODEL CONSTRUCTOR: 1') self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) # print('REL MODEL CONSTRUCTOR: 2') self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # print('REL MODEL CONSTRUCTOR: 3') if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # print('REL MODEL CONSTRUCTOR: 4') ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() # print('REL MODEL CONSTRUCTOR: 5') if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() # print('REL MODEL CONSTRUCTOR: over') @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)( features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): """ Forward pass for detection :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE] :param im_sizes: A numpy array of (h, w, scale) for each image. :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0) :param gt_boxes: Training parameters: :param gt_boxes: [num_gt, 4] GT boxes over the batch. :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will be used to compute the training loss. 
Each (img_ind, fpn_idx) :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels if test: prob dists, boxes, img inds, maxscores, classes """ result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") im_inds = result.im_inds - image_offset boxes = result.rm_box_priors if self.training and result.rel_labels is None: assert self.mode == 'sgdet' result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) rois = torch.cat((im_inds[:, None].float(), boxes), 1) result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, edge_ctx = self.context( result.obj_fmap, result.rm_obj_dists.detach(), im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None, boxes.data, result.boxes_all) if edge_ctx is None: edge_rep = self.post_emb(result.obj_preds) else: edge_rep = self.post_lstm(edge_ctx) # Split into subject and object representations edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim) subj_rep = edge_rep[:, 0] obj_rep = edge_rep[:, 1] prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]] if self.use_vision: vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) if self.limit_vision: # exact value TBD prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) else: prod_rep = prod_rep * vr if self.use_tanh: prod_rep = F.tanh(prod_rep) result.rel_dists = self.rel_compress(prod_rep) if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training: return result twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
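# --- Illustrative sketch (not from the original code) ---
# The test branch above gathers each object's softmax score for its predicted class with a
# flat index: twod_inds = arange(num_objs) * num_classes + obj_preds. Toy values below show
# that this is equivalent to a per-row gather.
import torch
import torch.nn.functional as F

num_classes = 4
obj_logits = torch.randn(3, num_classes)              # stands in for result.rm_obj_dists
obj_preds = torch.tensor([2, 0, 3])                   # stands in for result.obj_preds

twod_inds = torch.arange(obj_preds.size(0)) * num_classes + obj_preds
obj_scores = F.softmax(obj_logits, dim=1).view(-1)[twod_inds]

assert torch.allclose(
    obj_scores,
    F.softmax(obj_logits, dim=1).gather(1, obj_preds[:, None]).squeeze(1))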
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, gnn=True, reachability=False, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.reachability = reachability self.gnn = gnn self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.global_logist = nn.Linear(4096, 151, bias=True) # CosineLinear(4096,150)# self.global_logist.weight = torch.nn.init.xavier_normal( self.global_logist.weight, gain=1.0) self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim + 256) self.meta_classify = MetaEmbedding_Classifier( feat_dim=self.pooling_dim + 256, num_classes=self.num_rels) # self.global_rel_logist = nn.Linear(4096, 50 , bias=True) # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0) # self.global_logist = CosineLinear(4096,150) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.edge_coordinate_embedding = nn.Sequential(*[ nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(5, 256), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) # Initialize to sqrt(1/2n) so that 
the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(4096 + 256, 51, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) self.node_transform = nn.Linear(4096, 256, bias=True) self.edge_transform = nn.Linear(4096, 256, bias=True) # self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels) # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() if self.gnn: self.graph_network_node = GraphNetwork(4096) self.graph_network_edge = GraphNetwork() if self.training: self.graph_network_node.train() self.graph_network_edge.train() else: self.graph_network_node.eval() self.graph_network_edge.eval() self.edge_sim_network = nn.Linear(4096, 1, bias=True) self.metric_net = MetricLearning()
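# --- Illustrative sketch (not from the original code) ---
# FrequencyBias adds a per-(subject class, object class) prior over predicates to the
# relation logits via index_with_labels. The stand-in below only mimics that interface:
# the real module is initialized from training-set statistics, not random weights.
import torch
import torch.nn as nn

class ToyFrequencyBias(nn.Module):
    def __init__(self, num_classes, num_rels):
        super().__init__()
        self.num_classes = num_classes
        # one row of predicate biases per ordered (subject, object) class pair
        self.baseline = nn.Embedding(num_classes * num_classes, num_rels)

    def index_with_labels(self, pair_labels):
        # pair_labels: [num_pairs, 2] of (subject class, object class)
        flat = pair_labels[:, 0] * self.num_classes + pair_labels[:, 1]
        return self.baseline(flat)

bias = ToyFrequencyBias(num_classes=151, num_rels=51)
print(bias.index_with_labels(torch.tensor([[5, 7], [20, 3]])).shape)   # [2, 51]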
class RelModel(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, model_path='', reachability=False, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, init_center=False, limit_vision=True): super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.init_center = init_center self.pooling_size = 7 self.model_path = model_path self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.centroids = None self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim) self.meta_classify = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_rels) self.disc_center_g = DiscCentroidsLoss(self.num_classes, self.pooling_dim) self.meta_classify_g = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_classes) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() self.global_logist = nn.Linear(self.pooling_dim, self.num_classes, bias=True) # CosineLinear(4096,150)# self.global_logist.weight = torch.nn.init.xavier_normal( self.global_logist.weight, gain=1.0) self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() self.class_num = torch.zeros(len(self.classes)) self.centroids = torch.zeros(len(self.classes), self.pooling_dim).cuda() @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat( (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def _to_one_hot(self, y, n_dims, dtype=torch.cuda.FloatTensor): scatter_dim = len(y.size()) y_tensor = y.type(torch.cuda.LongTensor).view(*y.size(), -1) zeros = torch.zeros(*y.size(), n_dims).type(dtype) return zeros.scatter(scatter_dim, y_tensor, 1) def obj_feature_map(self, features, rois): feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def center_calulate(self, feature, labels): for idx, i in enumerate(labels): self.centroids[i] += feature[idx] self.class_num[i] += 1 def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") im_inds = result.im_inds - image_offset boxes = result.rm_box_priors if self.training and result.rel_labels is None: assert self.mode == 'sgdet' result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) rois = torch.cat((im_inds[:, None].float(), boxes), 1) result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, node_rep0 = self.context( result.obj_fmap, result.rm_obj_dists.detach(), im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None, boxes.data, result.boxes_all) edge_rep = node_rep0.repeat(1, 2) edge_rep = 
edge_rep.view(edge_rep.size(0), 2, -1) global_feature = self.global_embedding(result.fmap.detach()) result.global_dists = self.global_logist(global_feature) one_hot_multi = torch.zeros( (result.global_dists.shape[0], self.num_classes)) one_hot_multi[im_inds, result.rm_obj_labels] = 1.0 result.multi_hot = one_hot_multi.float().cuda() subj_global_additive_attention = F.relu( self.global_sub_additive(edge_rep[:, 0] + global_feature[im_inds])) obj_global_additive_attention = F.relu( self.global_obj_additive(edge_rep[:, 1] + global_feature[im_inds])) subj_rep = edge_rep[:, 0] + subj_global_additive_attention * global_feature[ im_inds] obj_rep = edge_rep[:, 1] + obj_global_additive_attention * global_feature[ im_inds] if self.training: self.centroids = self.disc_center.centroids.data # if edge_ctx is None: # edge_rep = self.post_emb(result.obj_preds) # else: # edge_rep = self.post_lstm(edge_ctx) # Split into subject and object representations # edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim) # # subj_rep = edge_rep[:, 0] # obj_rep = edge_rep[:, 1] prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]] vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) prod_rep = prod_rep * vr prod_rep = F.tanh(prod_rep) logits, self.direct_memory_feature = self.meta_classify( prod_rep, self.centroids) # result.rel_dists = self.rel_compress(prod_rep) result.rel_dists = logits result.rel_dists2 = self.direct_memory_feature[-1] # result.hallucinate_logits = self.direct_memory_feature[-1] if self.training: result.center_loss = self.disc_center( prod_rep, result.rel_labels[:, -1]) * 0.01 if self.use_bias: result.rel_dists = result.rel_dists + 1.0 * self.freq_bias.index_with_labels( torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training: return result twod_inds = arange( result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view( result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply( replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
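# --- Illustrative sketch (not from the original code) ---
# The forward pass above mixes each object's edge representation with the image-level
# embedding through a learned additive gate. Shapes and layer sizes below are toy stand-ins
# for edge_rep, global_feature and the global_*_additive layers.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_objs = 4096, 5
edge_rep = torch.randn(num_objs, 2, dim)          # subject/object halves per box
global_feature = torch.randn(2, dim)              # one embedding per image
im_inds = torch.tensor([0, 0, 0, 1, 1])           # image index of each box

sub_additive = nn.Linear(dim, 1)                  # stands in for self.global_sub_additive
obj_additive = nn.Linear(dim, 1)                  # stands in for self.global_obj_additive

g = global_feature[im_inds]                       # broadcast image feature to its boxes
subj_gate = F.relu(sub_additive(edge_rep[:, 0] + g))
obj_gate = F.relu(obj_additive(edge_rep[:, 1] + g))
subj_rep = edge_rep[:, 0] + subj_gate * g         # gated residual mix-in
obj_rep = edge_rep[:, 1] + obj_gate * g
print(subj_rep.shape, obj_rep.shape)              # both [num_objs, dim]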
class EndCell(nn.Module): def __init__(self, classes, num_rels, mode='sgdet', embed_dim=200, pooling_dim=4096, use_bias=True): super(EndCell, self).__init__() self.classes = classes self.num_rels = num_rels assert mode in MODES self.embed_dim = embed_dim self.pooling_dim = pooling_dim self.use_bias = use_bias self.mode = mode self.ort_embedding = torch.autograd.Variable( get_ort_embeds(self.num_classes, self.embed_dim).cuda()) self.context = LC(classes=self.classes, mode=self.mode, embed_dim=self.embed_dim, obj_dim=self.pooling_dim) self.union_boxes = UnionBoxesAndFeats(pooling_size=7, stride=16, dim=512) self.pooling_size = 7 roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier self.post_lstm = nn.Linear(self.pooling_dim + self.embed_dim + 5, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.pooling_dim)) self.post_lstm.bias.data.zero_() self.post_emb = nn.Linear(self.pooling_dim + self.embed_dim + 5, self.pooling_dim * 2) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() @property def num_classes(self): return len(self.classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def visual_obj(self, features, rois, pair_inds): assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return uboxes def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat( (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. 
:return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def forward(self, last_outputs, obj_dists, rel_inds, im_inds, rois, boxes): twod_inds = arange(last_outputs.obj_preds.data ) * self.num_classes + last_outputs.obj_preds.data obj_scores = F.softmax(last_outputs.rm_obj_dists, dim=1).view(-1)[twod_inds] rel_rep, _ = F.softmax(last_outputs.rel_dists, dim=1)[:, 1:].max(1) rel_scores_argmaxed = rel_rep * obj_scores[ rel_inds[:, 0]] * obj_scores[rel_inds[:, 1]] _, rel_scores_idx = torch.sort(rel_scores_argmaxed.view(-1), dim=0, descending=True) rel_scores_idx = rel_scores_idx[:100] filtered_rel_inds = rel_inds[rel_scores_idx.data] obj_fmap = self.obj_feature_map(last_outputs.fmap.detach(), rois) rm_obj_dists, obj_preds = self.context( obj_fmap, obj_dists.detach(), im_inds, last_outputs.rm_obj_labels if self.mode == 'predcls' else None, boxes.data, last_outputs.boxes_all) obj_dtype = obj_fmap.data.type() obj_preds_embeds = torch.index_select(self.ort_embedding, 0, obj_preds).type(obj_dtype) transfered_boxes = torch.stack( (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE, boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE, ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) / (IM_SCALE**2)), -1).type(obj_dtype) obj_features = torch.cat( (obj_fmap, obj_preds_embeds, transfered_boxes), -1) edge_rep = self.post_emb(obj_features) edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim) subj_rep = edge_rep[:, 0][filtered_rel_inds[:, 1]] obj_rep = edge_rep[:, 1][filtered_rel_inds[:, 2]] prod_rep = subj_rep * obj_rep vr = self.visual_rep(last_outputs.fmap.detach(), rois, filtered_rel_inds[:, 1:]) prod_rep = prod_rep * vr rel_dists = self.rel_compress(prod_rep) if self.use_bias: rel_dists = rel_dists + self.freq_bias.index_with_labels( torch.stack(( obj_preds[filtered_rel_inds[:, 1]], obj_preds[filtered_rel_inds[:, 2]], ), 1)) return filtered_rel_inds, rm_obj_dists, obj_preds, rel_dists
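# --- Illustrative sketch (not from the original code) ---
# EndCell.forward starts by re-ranking candidate pairs with
# (best non-background predicate prob) * (subject score) * (object score) and keeping the
# top 100. The tensors below are random stand-ins with the same shapes.
import torch
import torch.nn.functional as F

num_pairs, num_rels, num_boxes = 500, 51, 80
rel_logits = torch.randn(num_pairs, num_rels)
obj_scores = torch.rand(num_boxes)                       # score of each box's predicted class
pair_inds = torch.randint(0, num_boxes, (num_pairs, 2))  # (subject box, object box) indices

rel_prob, _ = F.softmax(rel_logits, dim=1)[:, 1:].max(1) # drop the background predicate
combined = rel_prob * obj_scores[pair_inds[:, 0]] * obj_scores[pair_inds[:, 1]]
_, order = torch.sort(combined, descending=True)
keep = pair_inds[order[:100]]                            # pairs passed to the second stage
print(keep.shape)                                        # torch.Size([100, 2])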
def __init__(self, classes, rel_classes, mode='sgdet', use_vision=True, embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=2048, pooling_size=7, dropout_rate=0.2, use_bias=True, use_tanh=True, limit_vision=True, sl_pretrain=False, num_iter=-1, use_resnet=False, reduce_input=False, debug_type=None, post_nms_thresh=0.5): super(DynamicFilterContext, self).__init__() self.classes = classes self.rel_classes = rel_classes assert mode in MODES self.mode = mode self.use_vision = use_vision self.use_bias = use_bias self.use_tanh = use_tanh self.use_highway = True self.limit_vision = limit_vision self.pooling_dim = pooling_dim self.pooling_size = pooling_size self.nms_thresh = post_nms_thresh self.obj_compress = myNNLinear(self.pooling_dim, self.num_classes, bias=True) # self.roi_fmap_obj = load_vgg(pretrained=False).classifier roi_fmap_obj = [myNNLinear(512*self.pooling_size*self.pooling_size, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5), myNNLinear(4096, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5)] self.roi_fmap_obj = nn.Sequential(*roi_fmap_obj) if self.use_bias: self.freq_bias = FrequencyBias() self.reduce_dim = 256 self.reduce_obj_fmaps = nn.Conv2d(512, self.reduce_dim, kernel_size=1) similar_fun = [myNNLinear(self.reduce_dim*2, self.reduce_dim), nn.ReLU(inplace=True), myNNLinear(self.reduce_dim, 1)] self.similar_fun = nn.Sequential(*similar_fun) # roi_fmap = [Flattener(), # load_vgg(use_dropout=False, use_relu=False, use_linear=self.pooling_dim == 4096, pretrained=False).classifier,] # if self.pooling_dim != 4096: # roi_fmap.append(nn.Linear(4096, self.pooling_dim)) # self.roi_fmap = nn.Sequential(*roi_fmap) roi_fmap = [Flattener(), nn.Linear(self.reduce_dim*2*self.pooling_size*self.pooling_size, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5), nn.Linear(4096, 4096, bias=True)] self.roi_fmap = nn.Sequential(*roi_fmap) self.hidden_dim = hidden_dim self.rel_compress = myNNLinear(self.hidden_dim*3, self.num_rels) self.post_obj = myNNLinear(self.pooling_dim, self.hidden_dim*2) self.mapping_x = myNNLinear(self.hidden_dim*2, self.hidden_dim*3) self.reduce_rel_input = myNNLinear(self.pooling_dim, self.hidden_dim*3)
class DynamicFilterContext(nn.Module): def __init__(self, classes, rel_classes, mode='sgdet', use_vision=True, embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=2048, pooling_size=7, dropout_rate=0.2, use_bias=True, use_tanh=True, limit_vision=True, sl_pretrain=False, num_iter=-1, use_resnet=False, reduce_input=False, debug_type=None, post_nms_thresh=0.5): super(DynamicFilterContext, self).__init__() self.classes = classes self.rel_classes = rel_classes assert mode in MODES self.mode = mode self.use_vision = use_vision self.use_bias = use_bias self.use_tanh = use_tanh self.use_highway = True self.limit_vision = limit_vision self.pooling_dim = pooling_dim self.pooling_size = pooling_size self.nms_thresh = post_nms_thresh self.obj_compress = myNNLinear(self.pooling_dim, self.num_classes, bias=True) # self.roi_fmap_obj = load_vgg(pretrained=False).classifier roi_fmap_obj = [myNNLinear(512*self.pooling_size*self.pooling_size, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5), myNNLinear(4096, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5)] self.roi_fmap_obj = nn.Sequential(*roi_fmap_obj) if self.use_bias: self.freq_bias = FrequencyBias() self.reduce_dim = 256 self.reduce_obj_fmaps = nn.Conv2d(512, self.reduce_dim, kernel_size=1) similar_fun = [myNNLinear(self.reduce_dim*2, self.reduce_dim), nn.ReLU(inplace=True), myNNLinear(self.reduce_dim, 1)] self.similar_fun = nn.Sequential(*similar_fun) # roi_fmap = [Flattener(), # load_vgg(use_dropout=False, use_relu=False, use_linear=self.pooling_dim == 4096, pretrained=False).classifier,] # if self.pooling_dim != 4096: # roi_fmap.append(nn.Linear(4096, self.pooling_dim)) # self.roi_fmap = nn.Sequential(*roi_fmap) roi_fmap = [Flattener(), nn.Linear(self.reduce_dim*2*self.pooling_size*self.pooling_size, 4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5), nn.Linear(4096, 4096, bias=True)] self.roi_fmap = nn.Sequential(*roi_fmap) self.hidden_dim = hidden_dim self.rel_compress = myNNLinear(self.hidden_dim*3, self.num_rels) self.post_obj = myNNLinear(self.pooling_dim, self.hidden_dim*2) self.mapping_x = myNNLinear(self.hidden_dim*2, self.hidden_dim*3) self.reduce_rel_input = myNNLinear(self.pooling_dim, self.hidden_dim*3) def obj_feature_map(self, features, rois): feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)( features, rois) return feature_pool # return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) @property def is_sgdet(self): return self.mode == 'sgdet' @property def is_sgcls(self): return self.mode == 'sgcls' def forward(self, *args, **kwargs): results = self.base_forward(*args, **kwargs) return results def base_forward(self, fmaps, obj_logits, im_inds, rel_inds, msg_rel_inds, reward_rel_inds, im_sizes, boxes_priors=None, boxes_deltas=None, boxes_per_cls=None, obj_labels=None): assert self.mode == 'sgcls' num_objs = obj_logits.shape[0] num_rels = rel_inds.shape[0] rois = torch.cat((im_inds[:, None].float(), boxes_priors), 1) obj_fmaps = self.obj_feature_map(fmaps, rois) reduce_obj_fmaps = self.reduce_obj_fmaps(obj_fmaps) S_fmaps = reduce_obj_fmaps[rel_inds[:, 1]] O_fmaps = reduce_obj_fmaps[rel_inds[:, 2]] if conf.debug_type in ['test1_0']: last_SO_fmaps = torch.cat((S_fmaps, O_fmaps), dim=1) elif conf.debug_type in ['test1_1']: S_fmaps_trans = S_fmaps.view(num_rels, self.reduce_dim, self.pooling_size*self.pooling_size).transpose(2, 1) O_fmaps_trans = 
O_fmaps.view(num_rels, self.reduce_dim, self.pooling_size*self.pooling_size).transpose(2, 1) pooling_size_sq = self.pooling_size*self.pooling_size S_fmaps_extend = S_fmaps_trans.repeat(1, 1, pooling_size_sq).view(num_rels, pooling_size_sq*pooling_size_sq, self.reduce_dim) O_fmaps_extend = O_fmaps_trans.repeat(1, pooling_size_sq, 1) SO_fmaps_extend = torch.cat((S_fmaps_extend, O_fmaps_extend), dim=2) SO_fmaps_logits = self.similar_fun(SO_fmaps_extend) SO_fmaps_logits = SO_fmaps_logits.view(num_rels, pooling_size_sq, pooling_size_sq) # (first dim is S_fmaps, second dim is O_fmaps) SO_fmaps_scores = F.softmax(SO_fmaps_logits, dim=1) weighted_S_fmaps = torch.matmul(SO_fmaps_scores.transpose(2, 1), S_fmaps_trans) # (num_rels, 49, 49) x (num_rels, 49, self.reduce_dim) last_SO_fmaps = torch.cat((weighted_S_fmaps, O_fmaps_trans), dim=2) last_SO_fmaps = last_SO_fmaps.transpose(2, 1).contiguous().view(num_rels, self.reduce_dim*2, self.pooling_size, self.pooling_size) else: raise ValueError # for object classification obj_feats = self.roi_fmap_obj(obj_fmaps.view(rois.size(0), -1)) obj_logits = self.obj_compress(obj_feats) obj_dists = F.softmax(obj_logits, dim=1) pred_obj_cls = obj_dists[:, 1:].max(1)[1] + 1 # for relationship classification rel_input = self.roi_fmap(last_SO_fmaps) subobj_rep = self.post_obj(obj_feats) sub_rep = subobj_rep[:, :self.hidden_dim][rel_inds[:, 1]] obj_rep = subobj_rep[:, self.hidden_dim:][rel_inds[:, 2]] last_rel_input = self.reduce_rel_input(rel_input) last_obj_input = self.mapping_x(torch.cat((sub_rep, obj_rep), 1)) triple_rep = nn.ReLU(inplace=True)(last_obj_input + last_rel_input) - (last_obj_input - last_rel_input).pow(2) rel_logits = self.rel_compress(triple_rep) # follow neural-motifs paper if self.use_bias: if self.mode in ['sgcls', 'sgdet']: rel_logits = rel_logits + self.freq_bias.index_with_labels( torch.stack(( pred_obj_cls[rel_inds[:, 1]], pred_obj_cls[rel_inds[:, 2]], ), 1)) elif self.mode == 'predcls': rel_logits = rel_logits + self.freq_bias.index_with_labels( torch.stack(( obj_labels[rel_inds[:, 1]], obj_labels[rel_inds[:, 2]], ), 1)) else: raise NotImplementedError return pred_obj_cls, obj_logits, rel_logits
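# --- Illustrative sketch (not from the original code) ---
# Shape walk-through of the 'test1_1' branch above: every cell of the object's 7x7 grid
# attends over all 49 cells of the subject's grid, and the attended subject features are
# concatenated back onto the object grid. The tensors and the similarity MLP are random.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_rels, d, p = 3, 256, 7
S = torch.randn(num_rels, d, p, p)                       # subject RoI features
O = torch.randn(num_rels, d, p, p)                       # object RoI features
similar_fun = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(), nn.Linear(d, 1))

S_t = S.view(num_rels, d, p * p).transpose(2, 1)         # [num_rels, 49, d]
O_t = O.view(num_rels, d, p * p).transpose(2, 1)
S_ext = S_t.repeat(1, 1, p * p).view(num_rels, p ** 4, d)  # every (subj cell, obj cell) pair
O_ext = O_t.repeat(1, p * p, 1)
logits = similar_fun(torch.cat((S_ext, O_ext), 2)).view(num_rels, p * p, p * p)
scores = F.softmax(logits, dim=1)                        # normalize over subject cells
weighted_S = torch.matmul(scores.transpose(2, 1), S_t)   # [num_rels, 49, d]
fused = torch.cat((weighted_S, O_t), 2).transpose(2, 1).contiguous()
print(fused.view(num_rels, 2 * d, p, p).shape)           # torch.Size([3, 512, 7, 7])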
def __init__(self, train_data, mode='sgcls', require_overlap_det=True, use_bias=False, test_bias=False, backbone='vgg16', RELS_PER_IMG=1024, min_size=None, max_size=None, edge_model='motifs'): """ Base class for an SGG model :param mode: (sgcls, predcls, or sgdet) :param require_overlap_det: Whether two objects must intersect """ super(RelModelBase, self).__init__() self.classes = train_data.ind_to_classes self.rel_classes = train_data.ind_to_predicates self.mode = mode self.backbone = backbone self.RELS_PER_IMG = RELS_PER_IMG self.pool_sz = 7 self.stride = 16 self.use_bias = use_bias self.test_bias = test_bias self.require_overlap = require_overlap_det and self.mode == 'sgdet' if self.backbone == 'resnet50': self.obj_dim = 1024 self.fmap_sz = 21 if min_size is None: min_size = 1333 if max_size is None: max_size = 1333 print('\nLoading COCO pretrained model maskrcnn_resnet50_fpn...\n') # See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html self.detector = torchvision.models.detection.maskrcnn_resnet50_fpn( pretrained=True, min_size=min_size, max_size=max_size, box_detections_per_img=50, box_score_thresh=0.2) in_features = self.detector.roi_heads.box_predictor.cls_score.in_features # replace the pre-trained head with a new one self.detector.roi_heads.box_predictor = FastRCNNPredictor( in_features, len(self.classes)) self.detector.roi_heads.mask_predictor = None layers = list(self.detector.roi_heads.children())[:2] self.roi_fmap_obj = copy.deepcopy(layers[1]) self.roi_fmap = copy.deepcopy(layers[1]) self.roi_pool = copy.deepcopy(layers[0]) elif self.backbone == 'vgg16': self.obj_dim = 4096 self.fmap_sz = 38 if min_size is None: min_size = IM_SCALE if max_size is None: max_size = IM_SCALE vgg = load_vgg(use_dropout=False, use_relu=False, use_linear=True, pretrained=False) vgg.features.out_channels = 512 anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), )) roi_pooler = torchvision.ops.MultiScaleRoIAlign( featmap_names=['0'], output_size=self.pool_sz, sampling_ratio=2) self.detector = FasterRCNN(vgg.features, min_size=min_size, max_size=max_size, rpn_anchor_generator=anchor_generator, box_head=TwoMLPHead( vgg.features.out_channels * self.pool_sz**2, self.obj_dim), box_predictor=FastRCNNPredictor( self.obj_dim, len(train_data.ind_to_classes)), box_roi_pool=roi_pooler, box_detections_per_img=50, box_score_thresh=0.2) self.roi_fmap = nn.Sequential(nn.Flatten(), vgg.classifier) self.roi_fmap_obj = load_vgg(pretrained=False).classifier self.roi_pool = copy.deepcopy( list(self.detector.roi_heads.children())[0]) else: raise NotImplementedError(self.backbone) self.edge_dim = self.detector.backbone.out_channels self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pool_sz, stride=self.stride, dim=self.edge_dim, edge_model=edge_model) if self.use_bias: self.freq_bias = FrequencyBias(train_data)
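# --- Illustrative sketch (not from the original code) ---
# Condensed version of the resnet50 branch above: take torchvision's COCO-pretrained
# Mask R-CNN, swap the box predictor for one sized to this dataset's classes, and drop the
# mask head. num_classes = 151 is an assumption (VG objects + background); adjust as needed.
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 151
detector = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=True, box_detections_per_img=50, box_score_thresh=0.2)
in_features = detector.roi_heads.box_predictor.cls_score.in_features
detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
detector.roi_heads.mask_predictor = None      # masks are not needed for scene graph work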
class RelModelLinknet(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=4096, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModelLinknet, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.ctx_dim = 1024 if use_resnet else 512 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, pooling_dim=self.pooling_dim, ctx_dim=self.ctx_dim) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # Global Context Encoding self.GCE = GlobalContextEncoding(num_classes=self.num_classes, ctx_dim=self.ctx_dim) ################################### # K2 self.pos_embed = nn.Sequential(*[ nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(4, 128), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) # fc4 self.rel_compress = nn.Linear(self.pooling_dim + 128, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. 
:param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat( (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def geo_layout_enc(self, box_priors, rel_inds): """ Geometric Layout Encoding :param box_priors: [num_rois, 4] of (xmin, ymin, xmax, ymax) :param rel_inds: [num_rels, 3] of (img ind, box0 ind, box1 ind) :return: bos: [num_rois*(num_rois-1), 4] encoded relative geometric layout: bo|s """ cxcywh = center_size(box_priors.data) # convert to (cx, cy, w, h) box_s = cxcywh[rel_inds[:, 1]] box_o = cxcywh[rel_inds[:, 2]] # relative location rlt_loc_x = torch.div((box_o[:, 0] - box_s[:, 0]), box_s[:, 2]).view(-1, 1) rlt_loc_y = torch.div((box_o[:, 1] - box_s[:, 1]), box_s[:, 3]).view(-1, 1) # scale information scl_info_w = torch.log(torch.div(box_o[:, 2], box_s[:, 2])).view(-1, 1) scl_info_h = torch.log(torch.div(box_o[:, 3], box_s[:, 3])).view(-1, 1) bos = torch.cat((rlt_loc_x, rlt_loc_y, scl_info_w, scl_info_h), 1) return bos def glb_context_enc(self, features, im_inds, gt_classes, image_offset): """ Global Context Encoding :param features: [batch_size, ctx_dim, IM_SIZE/4, IM_SIZE/4] fmap features :param im_ind: [num_rois] image index :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0) :return: context_features: [num_rois, ctx_dim] stacked context_feature c according to im_ind gce_obj_dists: [batch_size, num_classes] softmax of predicted multi-label distribution: M gce_obj_labels: [batch_size, num_classes] ground truth multi-labels """ context_feature, gce_obj_dists = self.GCE(features) context_features = context_feature[im_inds] gce_obj_labels = torch.zeros_like(gce_obj_dists) gce_obj_labels[gt_classes[:, 0] - image_offset, gt_classes[:, 1]] = 1 return context_features, gce_obj_dists, gce_obj_labels def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): """ Forward pass for relationship :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE] :param im_sizes: A numpy array of (h, w, scale) for each image. 
:param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0) :param gt_boxes: Training parameters: :param gt_boxes: [num_gt, 4] GT boxes over the batch. :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will be used to compute the training loss. Each (img_ind, fpn_idx) :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels if test: prob dists, boxes, img inds, maxscores, classes """ result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") im_inds = result.im_inds - image_offset boxes = result.rm_box_priors if self.training and result.rel_labels is None: assert self.mode == 'sgdet' result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) rois = torch.cat((im_inds[:, None].float(), boxes), 1) result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # c M context_features, result.gce_obj_dists, result.gce_obj_labels = self.glb_context_enc( result.fmap.detach(), im_inds.data, gt_classes.data, image_offset) # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, edge_rep = self.context( result.obj_fmap, result.rm_obj_dists.detach(), context_features.detach(), result.rm_obj_labels if self.training or self.mode == 'predcls' else None, result.boxes_all) # Split into subject and object representations edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim) # E1 subj_rep = edge_rep[:, 0] # E1_s obj_rep = edge_rep[:, 1] # E1_o prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]] # G0 if self.use_vision: vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) # F if self.limit_vision: # exact value TBD prod_rep = torch.cat( (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1) else: prod_rep = prod_rep * vr if self.use_tanh: prod_rep = F.tanh(prod_rep) bos = self.geo_layout_enc(boxes, rel_inds) # bo|s pos_embed = self.pos_embed(Variable(bos)) result.rel_dists = self.rel_compress( torch.cat((prod_rep, pos_embed), 1)) # G2 if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels( torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training: return result twod_inds = arange( result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view( result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply( replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
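# --- Illustrative sketch (not from the original code) ---
# How glb_context_enc builds its multi-label target: a [num_images, num_classes] matrix
# with a 1 for every object class present in each image. Toy values below.
import torch

num_images, num_classes = 2, 6
gt_classes = torch.tensor([[0, 3], [0, 5], [1, 2], [1, 3]])   # rows are (img index, class)
image_offset = 0                                              # nonzero only for MGPU shards

gce_obj_labels = torch.zeros(num_images, num_classes)
gce_obj_labels[gt_classes[:, 0] - image_offset, gt_classes[:, 1]] = 1
print(gce_obj_labels)
# tensor([[0., 0., 0., 1., 0., 1.],
#         [0., 0., 1., 1., 0., 0.]])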
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, model_path='', reachability=False, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, init_center=False, limit_vision=True): super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.init_center = init_center self.pooling_size = 7 self.model_path = model_path self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.centroids = None self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim) self.meta_classify = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_rels) self.disc_center_g = DiscCentroidsLoss(self.num_classes, self.pooling_dim) self.meta_classify_g = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_classes) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        self.global_logist = nn.Linear(self.pooling_dim, self.num_classes, bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(self.global_logist.weight, gain=1.0)

        self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
        self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)

        if self.use_bias:
            self.freq_bias = FrequencyBias()

        self.class_num = torch.zeros(len(self.classes))
        self.centroids = torch.zeros(len(self.classes), self.pooling_dim).cuda()
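# --- Illustrative sketch (not from the original code) ---
# The class_num / centroids buffers created above are filled by center_calulate in the full
# model: features are summed per label and counts tracked, so a per-class mean can be taken
# later. The averaging step here is an assumption for illustration; CPU tensors are used
# instead of .cuda() so the snippet runs anywhere.
import torch

num_classes, dim = 4, 8
class_num = torch.zeros(num_classes)
centroids = torch.zeros(num_classes, dim)

features = torch.randn(10, dim)
labels = torch.randint(0, num_classes, (10,))
for idx, lbl in enumerate(labels):
    centroids[lbl] += features[idx]
    class_num[lbl] += 1

means = centroids / class_num.clamp(min=1).unsqueeze(1)   # avoid division by zero counts
print(means.shape)                                        # torch.Size([4, 8])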
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=4096, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModelLinknet, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.ctx_dim = 1024 if use_resnet else 512 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, pooling_dim=self.pooling_dim, ctx_dim=self.ctx_dim) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # Global Context Encoding self.GCE = GlobalContextEncoding(num_classes=self.num_classes, ctx_dim=self.ctx_dim) ################################### # K2 self.pos_embed = nn.Sequential(*[ nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(4, 128), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) # fc4 self.rel_compress = nn.Linear(self.pooling_dim + 128, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias()
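# --- Illustrative sketch (assumption, not the repo's exact helper) ----------
# The 4-d input to `self.pos_embed` above is presumably a (cx, cy, w, h) box
# encoding normalized by image size. A stand-alone version of such an encoding
# might look like this; `center_size_boxes` is a hypothetical name.
import torch

def center_size_boxes(boxes_xyxy, img_w, img_h):
    """boxes_xyxy: [N, 4] as (x1, y1, x2, y2) -> normalized (cx, cy, w, h)."""
    x1, y1, x2, y2 = boxes_xyxy.unbind(dim=1)
    w, h = x2 - x1, y2 - y1
    cx, cy = x1 + 0.5 * w, y1 + 0.5 * h
    return torch.stack((cx / img_w, cy / img_h, w / img_w, h / img_h), dim=1)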
class RelModel(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ Args: classes: list, list of 151 object class names(including background) rel_classes: list, list of 51 predicate names( including background(norelationship)) mode: string, 'sgdet', 'predcls' or 'sgcls' num_gpus: integer, number of GPUs to use use_vision: boolean, whether to use vision in the final product require_overlap_det: boolean, whether two object must intersect embed_dim: integer, number of dimension for all embeddings hidden_dim: integer, hidden size of LSTM pooling_dim: integer, outputsize of vgg fc layer nl_obj: integer, number of object context layer, 2 in paper nl_edge: integer, number of edge context layer, 4 in paper use_resnet: integer, use resnet for backbone order: string, value must be in ('size', 'confidence', 'random', 'leftright'), order of RoIs thresh: float, threshold for scores of boxes if score of box smaller than thresh, then it will be abandoned use_proposals: boolean, whether to use proposals pass_in_obj_feats_to_decoder: boolean, whether to pass object features to decoder RNN pass_in_obj_feats_to_edge: boolean, whether to pass object features to edge context RNN rec_dropout: float, dropout rate in RNN use_bias: boolean, use_tanh: boolean, limit_vision: boolean, """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. 
# In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat( (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): """Forward pass for detection Args: x: Images@[batch_size, 3, IM_SIZE, IM_SIZE] im_sizes: A numpy array of (h, w, scale) for each image. image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0) Training parameters: gt_boxes: [num_gt, 4] GT boxes over the batch. gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) gt_rels: proposals: train_anchor_inds: a [num_train, 2] array of indices for the anchors that will be used to compute the training loss. 
Each (img_ind, fpn_idx) return_fmap: Returns: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels If test: prob dists, boxes, img inds, maxscores, classes """ result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) """ Results attributes: od_obj_dists: logits after score_fc in RCNN rm_obj_dists: od_obj_dists after nms obj_scores: nmn obj_preds=None, obj_fmap=None, od_box_deltas=None, rm_box_deltas=None, od_box_targets=None, rm_box_targets=None, od_box_priors: proposal before nms rm_box_priors: proposal after nms boxes_assigned=None, boxes_all=None, od_obj_labels=None, rm_obj_labels=None, rpn_scores=None, rpn_box_deltas=None, rel_labels=None, im_inds: image index of every proposal fmap=None, rel_dists=None, rel_inds=None, rel_rep=None one example: sgcls task: result.fmap: torch.Size([6, 512, 37, 37]) result.im_inds: torch.Size([44]) result.obj_fmap: torch.Size([44, 4096]) result.od_box_priors: torch.Size([44, 4]) result.od_obj_dists: torch.Size([44, 151]) result.od_obj_labels: torch.Size([44]) result.rel_labels: torch.Size([316, 4]) result.rm_box_priors: torch.Size([44, 4]) result.rm_obj_dists: torch.Size([44, 151]) result.rm_obj_labels: torch.Size([44]) """ if result.is_none(): raise ValueError("heck") # image_offset refers to Blob # self.batch_size_per_gpu * index im_inds = result.im_inds - image_offset boxes = result.rm_box_priors #embed(header='rel_model.py before rel_assignments') if self.training and result.rel_labels is None: assert self.mode == 'sgdet' # only in sgdet mode # shapes: # im_inds: (box_num,) # boxes: (box_num, 4) # rm_obj_labels: (box_num,) # gt_boxes: (box_num, 4) # gt_classes: (box_num, 2) maybe[im_ind, class_ind] # gt_rels: (rel_num, 4) # image_offset: integer result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) #embed(header='rel_model.py after rel_assignments') # rel_labels[:, :3] if sgcls rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) rois = torch.cat((im_inds[:, None].float(), boxes), 1) # obj_fmap: (NumOfRoI, 4096) # RoIAlign result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, edge_ctx = self.context( result.obj_fmap, result.rm_obj_dists.detach(), im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None, boxes.data, result.boxes_all) if edge_ctx is None: edge_rep = self.post_emb(result.obj_preds) else: edge_rep = self.post_lstm(edge_ctx) # Split into subject and object representations edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim) subj_rep = edge_rep[:, 0] obj_rep = edge_rep[:, 1] prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]] # embed(header='rel_model.py prod_rep') if self.use_vision: vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) if self.limit_vision: # exact value TBD prod_rep = torch.cat( (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1) else: prod_rep = prod_rep * vr if self.use_tanh: prod_rep = F.tanh(prod_rep) result.rel_dists = self.rel_compress(prod_rep) if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels( torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) #embed(header='rel model return ') if self.training: #
embed(header='rel_model.py before return') # what will be useful: # rm_obj_dists, rm_obj_labels # rel_labels, rel_dists return result twod_inds = arange( result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view( result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) #embed(header='rel_model.py before return') return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply( replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
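# --- Illustrative sketch (not part of the model) -----------------------------
# The `twod_inds` trick used in the test branch above selects, for each box,
# the softmax score of its predicted class by indexing into the flattened
# [N, C] matrix. torch.gather gives the same result; shapes below are made up
# purely for the demo.
import torch
import torch.nn.functional as F

dists = torch.randn(5, 151)                       # [N, C] object class logits
preds = dists.argmax(dim=1)                       # [N] predicted classes
probs = F.softmax(dists, dim=1)
twod_inds = torch.arange(dists.size(0)) * dists.size(1) + preds
scores_flat = probs.view(-1)[twod_inds]           # flattened-index version
scores_gather = probs.gather(1, preds[:, None]).squeeze(1)
assert torch.allclose(scores_flat, scores_gather)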
class RelModel(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, gnn=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.gnn = gnn self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.global_logist = nn.Linear(4096, 151, bias=True) # CosineLinear(4096,150)# self.global_logist.weight = torch.nn.init.xavier_normal( self.global_logist.weight, gain=1.0) # self.global_rel_logist = nn.Linear(4096, 50 , bias=True) # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0) # self.global_logist = CosineLinear(4096,150) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.edge_coordinate_embedding = nn.Sequential(*[ nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(5, 256), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(4096 + 256, 51, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) self.node_transform = nn.Linear(4096, 256, bias=True) self.edge_transform = nn.Linear(4096, 256, bias=True) # self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels) # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() if self.gnn: self.graph_network_node = GraphNetwork(4096) self.graph_network_edge = GraphNetwork() if self.training: self.graph_network_node.train() self.graph_network_edge.train() else: self.graph_network_node.eval() self.graph_network_edge.eval() self.edge_sim_network = nn.Linear(4096, 1, bias=True) self.metric_net = MetricLearning() @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat( (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. 
:return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def coordinate_feats(self, boxes, rel_inds): coordinate_rep = {} coordinate_rep['center'] = center_size(boxes) coordinate_rep['point'] = torch.cat( (boxes, coordinate_rep['center'][:, 2:]), 1) sub_coordnate = {} sub_coordnate['center'] = coordinate_rep['center'][rel_inds[:, 1]] sub_coordnate['point'] = coordinate_rep['point'][rel_inds[:, 1]] obj_coordnate = {} obj_coordnate['center'] = coordinate_rep['center'][rel_inds[:, 2]] obj_coordnate['point'] = coordinate_rep['point'][rel_inds[:, 2]] edge_of_coordinate_rep = torch.zeros(sub_coordnate['center'].size(0), 5).cuda().float() edge_of_coordinate_rep[:, 0] = (sub_coordnate['point'][:, 0] - obj_coordnate['center'][:, 0]) * 1.0 / \ obj_coordnate['center'][:, 2] edge_of_coordinate_rep[:, 1] = (sub_coordnate['point'][:, 1] - obj_coordnate['center'][:, 1]) * 1.0 / \ obj_coordnate['center'][:, 3] edge_of_coordinate_rep[:, 2] = (sub_coordnate['point'][:, 2] - obj_coordnate['center'][:, 0]) * 1.0 / \ obj_coordnate['center'][:, 2] edge_of_coordinate_rep[:, 3] = (sub_coordnate['point'][:, 3] - obj_coordnate['center'][:, 1]) * 1.0 / \ obj_coordnate['center'][:, 3] edge_of_coordinate_rep[:, 4] = sub_coordnate['point'][:, 4] * sub_coordnate['point'][:, 5] * 1.0 / \ obj_coordnate['center'][:, 2] \ / obj_coordnate['center'][:, 3] return edge_of_coordinate_rep def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") im_inds = result.im_inds - image_offset boxes = result.rm_box_priors if self.training and result.rel_labels is None: assert self.mode == 'sgdet' result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) rois = torch.cat((im_inds[:, None].float(), boxes), 1) global_feature = self.global_embedding(result.fmap.detach()) result.global_dists = self.global_logist(global_feature) # print(result.global_dists) # result.global_rel_dists = F.sigmoid(self.global_rel_logist(global_feature)) result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, node_rep0 = self.context( result.obj_fmap, result.rm_obj_dists.detach(), im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None, boxes.data, result.boxes_all) one_hot_multi = torch.zeros( (result.global_dists.shape[0], self.num_classes)) one_hot_multi[im_inds, result.rm_obj_labels] = 1.0 result.multi_hot = one_hot_multi.float().cuda() edge_rep = node_rep0.repeat(1, 2) edge_rep = edge_rep.view(edge_rep.size(0), 2, -1) global_feature_re = global_feature[im_inds] subj_global_additive_attention = F.relu( self.global_sub_additive(edge_rep[:, 0] + global_feature_re)) obj_global_additive_attention = F.relu( torch.sigmoid( self.global_obj_additive(edge_rep[:, 1] + global_feature_re))) subj_rep = edge_rep[:, 0] + subj_global_additive_attention * global_feature_re obj_rep = edge_rep[:, 1] + obj_global_additive_attention 
* global_feature_re edge_of_coordinate_rep = self.coordinate_feats(boxes.data, rel_inds) e_ij_coordinate_rep = self.edge_coordinate_embedding( edge_of_coordinate_rep) union_rep = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) edge_feat_init = union_rep prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[ rel_inds[:, 2]] * edge_feat_init prod_rep = torch.cat((prod_rep, e_ij_coordinate_rep), 1) if self.use_tanh: prod_rep = F.tanh(prod_rep) result.rel_dists = self.rel_compress(prod_rep) if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels( torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training: return result twod_inds = arange( result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view( result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply( replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
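# --- Illustrative sketch (stand-alone restatement, assumed shapes) ----------
# `coordinate_feats` above builds a 5-d geometric edge feature: the subject
# box corners expressed relative to the object box centre and scaled by the
# object box size, plus the ratio of the two box areas. Self-contained
# version with illustrative names:
import torch

def relative_box_feature(sub_xyxywh, obj_cxcywh):
    """sub_xyxywh: [N, 6] = (x1, y1, x2, y2, w, h); obj_cxcywh: [N, 4] = (cx, cy, w, h)."""
    ocx, ocy, ow, oh = obj_cxcywh.unbind(dim=1)
    return torch.stack((
        (sub_xyxywh[:, 0] - ocx) / ow,
        (sub_xyxywh[:, 1] - ocy) / oh,
        (sub_xyxywh[:, 2] - ocx) / ow,
        (sub_xyxywh[:, 3] - ocy) / oh,
        (sub_xyxywh[:, 4] * sub_xyxywh[:, 5]) / (ow * oh),
    ), dim=1)                                     # [N, 5], fed to the 5 -> 256 embedding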
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ Args: classes: list, list of 151 object class names(including background) rel_classes: list, list of 51 predicate names( including background(norelationship)) mode: string, 'sgdet', 'predcls' or 'sgcls' num_gpus: integer, number of GPUs to use use_vision: boolean, whether to use vision in the final product require_overlap_det: boolean, whether two object must intersect embed_dim: integer, number of dimension for all embeddings hidden_dim: integer, hidden size of LSTM pooling_dim: integer, outputsize of vgg fc layer nl_obj: integer, number of object context layer, 2 in paper nl_edge: integer, number of edge context layer, 4 in paper use_resnet: integer, use resnet for backbone order: string, value must be in ('size', 'confidence', 'random', 'leftright'), order of RoIs thresh: float, threshold for scores of boxes if score of box smaller than thresh, then it will be abandoned use_proposals: boolean, whether to use proposals pass_in_obj_feats_to_decoder: boolean, whether to pass object features to decoder RNN pass_in_obj_feats_to_edge: boolean, whether to pass object features to edge context RNN rec_dropout: float, dropout rate in RNN use_bias: boolean, use_tanh: boolean, limit_vision: boolean, """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. 
# In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias()
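# --- Illustrative numeric check (not part of the model) ---------------------
# The comment above initializes the post_lstm weights with std 10*sqrt(1/n) so
# that, for inputs with std ~0.1, each output unit has variance ~1:
# var(y) = n * (0.1^2) * (100 / n) = 1. Quick sanity check of that claim:
import torch

n = 256                                           # stand-in for hidden_dim
w = torch.randn(1024, n) * (10.0 * (1.0 / n) ** 0.5)
x = torch.randn(2000, n) * 0.1                    # "pre-lstm stuff" with std 0.1
print((x @ w.t()).std())                          # ~1.0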
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision=limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' # print('REL MODEL CONSTRUCTOR: 1') self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) # print('REL MODEL CONSTRUCTOR: 2') self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # print('REL MODEL CONSTRUCTOR: 3') if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # print('REL MODEL CONSTRUCTOR: 4') ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() # print('REL MODEL CONSTRUCTOR: 5') if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias()
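# --- Illustrative sketch (assumption about FrequencyBias's interface) -------
# `FrequencyBias.index_with_labels` is used throughout this file as a lookup
# over (subject class, object class) pairs that returns a bias over predicates.
# A minimal table of that shape could be implemented as below; the real module
# is initialised from dataset co-occurrence statistics rather than at random.
import torch
import torch.nn as nn

class PairFreqBias(nn.Module):
    def __init__(self, num_classes=151, num_rels=51):
        super(PairFreqBias, self).__init__()
        self.num_classes = num_classes
        self.table = nn.Embedding(num_classes * num_classes, num_rels)

    def index_with_labels(self, pair_labels):
        """pair_labels: [M, 2] long tensor of (subj_class, obj_class) -> [M, num_rels]."""
        idx = pair_labels[:, 0] * self.num_classes + pair_labels[:, 1]
        return self.table(idx)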
class RelModel(nn.Module): """ RELATIONSHIPS """ def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.1, use_bias=True, use_tanh=True, use_encoded_box=True, use_rl_tree=True, draw_tree=False, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.co_occour = np.load(CO_OCCOUR_PATH) self.co_occour = self.co_occour / self.co_occour.sum() self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.use_encoded_box = use_encoded_box self.use_rl_tree = use_rl_tree self.draw_tree = draw_tree self.limit_vision=limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.rl_train = False self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, use_rl_tree = self.use_rl_tree ) self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge, use_rl_tree=self.use_rl_tree, draw_tree = self.draw_tree) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(use_dropout=False, pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.hidden_dim * 2) self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() self.post_cat.weight = torch.nn.init.xavier_normal(self.post_cat.weight, gain=1.0) self.post_cat.bias.data.zero_() if self.use_encoded_box: # encode spatial info self.encode_spatial_1 = nn.Linear(32, 512) self.encode_spatial_2 = nn.Linear(512, self.pooling_dim) self.encode_spatial_1.weight.data.normal_(0, 1.0) self.encode_spatial_1.bias.data.zero_() self.encode_spatial_2.weight.data.normal_(0, 0.1) self.encode_spatial_2.bias.data.zero_() if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) def visual_rep(self, features, rois, pair_inds): """ Classify the features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. :param pair_inds inds to use when predicting :return: score_pred, a [num_rois, num_classes] array box_pred, a [num_rois, num_classes, 4] array """ assert pair_inds.size(1) == 2 uboxes = self.union_boxes(features, rois, pair_inds) return self.roi_fmap(uboxes) def get_rel_inds(self, rel_labels, im_inds, box_priors): # Get the relationship candidates if self.training and not self.use_rl_tree: rel_inds = rel_labels[:, :3].data.clone() else: rel_cands = im_inds.data[:, None] == im_inds.data[None] rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0 # Require overlap for detection if self.require_overlap: rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0) # if there are fewer then 100 things then we might as well add some? amt_to_add = 100 - rel_cands.long().sum() rel_cands = rel_cands.nonzero() if rel_cands.dim() == 0: rel_cands = im_inds.data.new(1, 2).fill_(0) rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) return rel_inds def obj_feature_map(self, features, rois): """ Gets the ROI features :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2) :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1]. 
:return: [num_rois, #dim] array """ feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)( features, rois) return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1)) def get_rel_label(self, im_inds, gt_rels, rel_inds): np_im_inds = im_inds.data.cpu().numpy() np_gt_rels = gt_rels.long().data.cpu().numpy() np_rel_inds = rel_inds.long().cpu().numpy() num_obj = int(im_inds.shape[0]) sub_id = np_rel_inds[:, 1] obj_id = np_rel_inds[:, 2] select_id = sub_id * num_obj + obj_id count = 0 offset = 0 slicedInds = np.where(np_im_inds == count)[0] label = np.array([0]*num_obj*num_obj, dtype=int) while(len(slicedInds) > 0): slice_len = len(slicedInds) selectInds = np.where(np_gt_rels[:,0] == count)[0] slicedRels = np_gt_rels[selectInds,:] flattenID = (slicedRels[:,1] + offset) * num_obj + (slicedRels[:,2] + offset) slicedLabel = slicedRels[:,3] label[flattenID] = slicedLabel count += 1 offset += slice_len slicedInds = np.where(np_im_inds == count)[0] return Variable(torch.from_numpy(label[select_id]).long().cuda()) def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): """ Forward pass for detection :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE] :param im_sizes: A numpy array of (h, w, scale) for each image. :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0) :param gt_boxes: Training parameters: :param gt_boxes: [num_gt, 4] GT boxes over the batch. :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will be used to compute the training loss. Each (img_ind, fpn_idx) :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels if test: prob dists, boxes, img inds, maxscores, classes """ result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") im_inds = result.im_inds - image_offset boxes = result.rm_box_priors if self.training and result.rel_labels is None: assert self.mode == 'sgdet' result.rel_labels, fg_rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset, filter_non_overlap=True, num_sample_per_gt=1) #if self.training and (not self.use_rl_tree): # generate arbitrary forest according to graph # arbitrary_forest = graph_to_trees(self.co_occour, result.rel_labels, gt_classes) #else: arbitrary_forest = None rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes) if self.use_rl_tree: result.rel_label_tkh = self.get_rel_label(im_inds, gt_rels, rel_inds) rois = torch.cat((im_inds[:, None].float(), boxes), 1) result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # whole image feature, used for virtual node batch_size = result.fmap.shape[0] image_rois = Variable(torch.randn(batch_size, 5).fill_(0).cuda()) for i in range(batch_size): image_rois[i, 0] = i image_rois[i, 1] = 0 image_rois[i, 2] = 0 image_rois[i, 3] = IM_SCALE image_rois[i, 4] = IM_SCALE image_fmap = self.obj_feature_map(result.fmap.detach(), image_rois) if self.mode != 'sgdet' and self.training: fg_rel_labels = result.rel_labels # Prevent gradients from flowing back into score_fc from elsewhere result.rm_obj_dists, result.obj_preds, edge_ctx, result.gen_tree_loss, result.entropy_loss, result.pair_gate, 
result.pair_gt = self.context( result.obj_fmap, result.rm_obj_dists.detach(), im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None, boxes.data, result.boxes_all, arbitrary_forest, image_rois, image_fmap, self.co_occour, fg_rel_labels if self.training else None, x) if edge_ctx is None: edge_rep = self.post_emb(result.obj_preds) else: edge_rep = self.post_lstm(edge_ctx) # Split into subject and object representations edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim) subj_rep = edge_rep[:, 0] obj_rep = edge_rep[:, 1] prod_rep = torch.cat((subj_rep[rel_inds[:, 1]], obj_rep[rel_inds[:, 2]]), 1) prod_rep = self.post_cat(prod_rep) if self.use_encoded_box: # encode spatial info assert(boxes.shape[1] == 4) # encoded_boxes: [box_num, (x1,y1,x2,y2,cx,cy,w,h)] encoded_boxes = tree_utils.get_box_info(boxes) # encoded_boxes_pair: [batch_szie, (box1, box2, unionbox, intersectionbox)] encoded_boxes_pair = tree_utils.get_box_pair_info(encoded_boxes[rel_inds[:, 1]], encoded_boxes[rel_inds[:, 2]]) # encoded_spatial_rep spatial_rep = F.relu(self.encode_spatial_2(F.relu(self.encode_spatial_1(encoded_boxes_pair)))) # element-wise multiply with prod_rep prod_rep = prod_rep * spatial_rep if self.use_vision: vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:]) if self.limit_vision: # exact value TBD prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) else: prod_rep = prod_rep * vr if self.use_tanh: prod_rep = F.tanh(prod_rep) result.rel_dists = self.rel_compress(prod_rep) if self.use_bias: result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack(( result.obj_preds[rel_inds[:, 1]], result.obj_preds[rel_inds[:, 2]], ), 1)) if self.training and (not self.rl_train): return result twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds] # Bbox regression if self.mode == 'sgdet': bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4) else: # Boxes will get fixed by filter_dets function. bboxes = result.rm_box_priors rel_rep = F.softmax(result.rel_dists, dim=1) if not self.rl_train: return filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels) else: return result, filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels) def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus))) outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
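# --- Illustrative sketch (assumed behaviour of the tree_utils helpers) ------
# `get_box_info` / `get_box_pair_info` above appear to produce, per relation,
# an 8-d encoding (x1, y1, x2, y2, cx, cy, w, h) for each box plus the same
# encoding of the pair's union and intersection boxes (8 * 4 = 32 dims,
# matching the Linear(32, 512) defined earlier). A self-contained approximation:
import torch

def box_info(b):                                   # b: [N, 4] in xyxy
    w, h = b[:, 2] - b[:, 0], b[:, 3] - b[:, 1]
    cx, cy = b[:, 0] + 0.5 * w, b[:, 1] + 0.5 * h
    return torch.cat((b, torch.stack((cx, cy, w, h), dim=1)), dim=1)   # [N, 8]

def box_pair_info(info1, info2):
    union = torch.cat((torch.min(info1[:, :2], info2[:, :2]),
                       torch.max(info1[:, 2:4], info2[:, 2:4])), dim=1)
    inter = torch.cat((torch.max(info1[:, :2], info2[:, :2]),
                       torch.min(info1[:, 2:4], info2[:, 2:4])), dim=1)
    # note: `inter` is degenerate when the two boxes do not overlap
    return torch.cat((info1, info2, box_info(union), box_info(inter)), dim=1)  # [N, 32]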
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=False, require_overlap_det=True, embed_dim=200, hidden_dim=4096, use_resnet=False, thresh=0.01, use_proposals=False, use_bias=True, limit_vision=True, depth_model=None, pretrained_depth=False, active_features=None, frozen_features=None, use_embed=False, **kwargs): """ :param classes: object classes :param rel_classes: relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: enable the contribution of union of bounding boxes :param require_overlap_det: whether two objects must intersect :param embed_dim: word2vec embeddings dimension :param hidden_dim: dimension of the fusion hidden layer :param use_resnet: use resnet as faster-rcnn's backbone :param thresh: faster-rcnn related threshold (Threshold for calling it a good box) :param use_proposals: whether to use region proposal candidates :param use_bias: enable frequency bias :param limit_vision: use truncated version of UoBB features :param depth_model: provided architecture for depth feature extraction :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features) :param frozen_features: what set of features should be frozen (e.g. 'd' : depth) :param use_embed: use word2vec embeddings """ RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus, require_overlap_det, active_features, frozen_features) self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.use_vision = use_vision self.use_bias = use_bias self.limit_vision = limit_vision # -- Store depth related parameters assert depth_model in DEPTH_MODELS self.depth_model = depth_model self.pretrained_depth = pretrained_depth self.depth_pooling_dim = DEPTH_DIMS[self.depth_model] self.use_embed = use_embed self.detector = nn.Module() features_size = 0 # -- Check whether ResNet is selected as faster-rcnn's backbone if use_resnet: raise ValueError( "The current model does not support ResNet as the Faster-RCNN's backbone." ) """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** This is the part where the different components of the proposed relation detection architecture are defined. In the case of RGB images, we have class probability distribution features, visual features, and the location ones. If we are considering depth images as well, we augment depth features too. 
""" # -- Visual features if self.has_visual: # -- Define faster R-CNN network and it's related feature extractors self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # -- Define union features if self.use_vision: # -- UoBB pooling module self.union_boxes = UnionBoxesAndFeats( pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # -- UoBB feature extractor roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=self.hidden_dim == 4096, pretrained=False).classifier, ] if self.hidden_dim != 4096: roi_fmap.append(nn.Linear(4096, self.hidden_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) # -- Define visual features hidden layer self.visual_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)), nn.ReLU(inplace=True), nn.Dropout(0.8) ]) self.visual_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_VISUAL # -- Location features if self.has_loc: # -- Define location features hidden layer self.location_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.location_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_LOC # -- Class features if self.has_class: if self.use_embed: # -- Define class embeddings embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim) self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed.weight.data = embed_vecs.clone() classme_input_dim = self.embed_dim if self.use_embed else self.num_classes # -- Define Class features hidden layer self.classme_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.classme_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_CLASS # -- Depth features if self.has_depth: # -- Initialize depth backbone self.depth_backbone = DepthCNN(depth_model=self.depth_model, pretrained=self.pretrained_depth) # -- Create a relation head which is used to carry on the feature extraction # from RoIs of depth features self.depth_rel_head = self.depth_backbone.get_classifier() # -- Define depth features hidden layer self.depth_rel_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)), nn.ReLU(inplace=True), nn.Dropout(0.6), ]) self.depth_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_DEPTH # -- Initialize frequency bias if needed if self.use_bias: self.freq_bias = FrequencyBias() # -- *** Fusion layer *** -- # -- A hidden layer for concatenated features (fusion features) self.fusion_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(features_size, self.hidden_dim)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) # -- Final FC layer which predicts the relations self.rel_out = xavier_init( nn.Linear(self.hidden_dim, self.num_rels, bias=True)) # -- Freeze the user specified features if self.frz_visual: self.freeze_module(self.detector) self.freeze_module(self.roi_fmap_obj) self.freeze_module(self.visual_hlayer) if self.use_vision: self.freeze_module(self.roi_fmap) self.freeze_module(self.union_boxes.conv) if self.frz_class: self.freeze_module(self.classme_hlayer) if self.frz_loc: self.freeze_module(self.location_hlayer) if self.frz_depth: self.freeze_module(self.depth_backbone) self.freeze_module(self.depth_rel_head) 
self.freeze_module(self.depth_rel_hlayer)
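# --- Illustrative sketch (assumption; `freeze_module` is defined on the base
# class elsewhere in the repo). Freezing a sub-module most plausibly means
# disabling gradients for all of its parameters:
def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False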
class DecoderRNN(torch.nn.Module): def __init__(self, classes, rel_classes, embed_dim, obj_dim, inputs_dim, hidden_dim, pooling_dim, recurrent_dropout_probability=0.2, use_highway=True, use_input_projection_bias=True, use_vision=True, use_bias=True, use_tanh=True, limit_vision=True, sl_pretrain=False, num_iter=-1): """ Initializes the RNN :param embed_dim: Dimension of the embeddings :param encoder_hidden_dim: Hidden dim of the encoder, for attention purposes :param hidden_dim: Hidden dim of the decoder :param vocab_size: Number of words in the vocab :param bos_token: To use during decoding (non teacher forcing mode)) :param bos: beginning of sentence token :param unk: unknown token (not used) """ super(DecoderRNN, self).__init__() self.rel_embedding_dim = 100 self.classes = classes self.rel_classes = rel_classes embed_vecs = obj_edge_vectors(['start'] + self.classes, wv_dim=100) self.obj_embed = nn.Embedding(len(self.classes), embed_dim) self.obj_embed.weight.data = embed_vecs embed_rels = obj_edge_vectors(self.rel_classes, wv_dim=self.rel_embedding_dim) self.rel_embed = nn.Embedding(len(self.rel_classes), self.rel_embedding_dim) self.rel_embed.weight.data = embed_rels self.embed_dim = embed_dim self.obj_dim = obj_dim self.hidden_size = hidden_dim self.inputs_dim = inputs_dim self.pooling_dim = pooling_dim self.nms_thresh = 0.3 self.use_vision = use_vision self.use_bias = use_bias self.use_tanh = use_tanh self.limit_vision = limit_vision self.sl_pretrain = sl_pretrain self.num_iter = num_iter self.recurrent_dropout_probability = recurrent_dropout_probability self.use_highway = use_highway # We do the projections for all the gates all at once, so if we are # using highway layers, we need some extra projections, which is # why the sizes of the Linear layers change here depending on this flag. if use_highway: self.input_linearity = torch.nn.Linear( self.input_size, 6 * self.hidden_size, bias=use_input_projection_bias) self.state_linearity = torch.nn.Linear(self.hidden_size, 5 * self.hidden_size, bias=True) else: self.input_linearity = torch.nn.Linear( self.input_size, 4 * self.hidden_size, bias=use_input_projection_bias) self.state_linearity = torch.nn.Linear(self.hidden_size, 4 * self.hidden_size, bias=True) # self.obj_in_lin = torch.nn.Linear(self.rel_embedding_dim, self.rel_embedding_dim, bias=True) self.out = nn.Linear(self.hidden_size, len(self.classes)) self.reset_parameters() # For relation predication embed_vecs2 = obj_edge_vectors(self.classes, wv_dim=embed_dim) self.obj_embed2 = nn.Embedding(self.num_classes, embed_dim) self.obj_embed2.weight.data = embed_vecs2.clone() # self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.post_lstm = nn.Linear(self.obj_dim + 2 * self.embed_dim + 128, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_( 0, 10.0 * math.sqrt(1.0 / self.hidden_size) ) ######## there may need more consideration self.post_lstm.bias.data.zero_() self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias() # simple relation model from dataloaders.visual_genome import VG from lib.get_dataset_counts import get_counts, box_filter fg_matrix, bg_matrix = get_counts(train_data=VG.splits( num_val_im=5000, filter_non_overlap=True, filter_duplicate_rels=True, use_proposals=False)[0], must_overlap=True) prob_matrix = fg_matrix.astype(np.float32) prob_matrix[:, :, 0] = bg_matrix # TRYING SOMETHING NEW. prob_matrix[:, :, 0] += 1 prob_matrix /= np.sum(prob_matrix, 2)[:, :, None] # prob_matrix /= float(fg_matrix.max()) prob_matrix[:, :, 0] = 0 # Zero out BG self.prob_matrix = prob_matrix @property def num_classes(self): return len(self.classes) @property def num_rels(self): return len(self.rel_classes) @property def input_size(self): return self.inputs_dim + self.obj_embed.weight.size(1) def reset_parameters(self): # Use sensible default initializations for parameters. block_orthogonal(self.input_linearity.weight.data, [self.hidden_size, self.input_size]) block_orthogonal(self.state_linearity.weight.data, [self.hidden_size, self.hidden_size]) self.state_linearity.bias.data.fill_(0.0) # Initialize forget gate biases to 1.0 as per An Empirical # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015). self.state_linearity.bias.data[self.hidden_size:2 * self.hidden_size].fill_(1.0) def lstm_equations(self, timestep_input, previous_state, previous_memory, dropout_mask=None): """ Does the hairy LSTM math :param timestep_input: :param previous_state: :param previous_memory: :param dropout_mask: :return: """ # Do the projections for all the gates all at once. projected_input = self.input_linearity(timestep_input) projected_state = self.state_linearity(previous_state) # Main LSTM equations using relevant chunks of the big linear # projections of the hidden state and inputs. input_gate = torch.sigmoid( projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] + projected_state[:, 0 * self.hidden_size:1 * self.hidden_size]) forget_gate = torch.sigmoid( projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] + projected_state[:, 1 * self.hidden_size:2 * self.hidden_size]) memory_init = torch.tanh( projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] + projected_state[:, 2 * self.hidden_size:3 * self.hidden_size]) output_gate = torch.sigmoid( projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] + projected_state[:, 3 * self.hidden_size:4 * self.hidden_size]) memory = input_gate * memory_init + forget_gate * previous_memory timestep_output = output_gate * torch.tanh(memory) if self.use_highway: highway_gate = torch.sigmoid( projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] + projected_state[:, 4 * self.hidden_size:5 * self.hidden_size]) highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size] timestep_output = highway_gate * timestep_output + ( 1 - highway_gate) * highway_input_projection # Only do dropout if the dropout prob is > 0.0 and we are in training mode. 
    def lstm_equations(self, timestep_input, previous_state, previous_memory,
                       dropout_mask=None):
        """
        Does the hairy LSTM math
        :param timestep_input: (batch_size, input_size) input at this timestep
        :param previous_state: (batch_size, hidden_size) previous hidden state
        :param previous_memory: (batch_size, hidden_size) previous cell state
        :param dropout_mask: optional recurrent dropout mask
        :return: (new hidden state, new cell state)
        """
        # Do the projections for all the gates all at once.
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(
            projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
            projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
        forget_gate = torch.sigmoid(
            projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
            projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
        memory_init = torch.tanh(
            projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
            projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
        output_gate = torch.sigmoid(
            projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
            projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
        memory = input_gate * memory_init + forget_gate * previous_memory
        timestep_output = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_gate = torch.sigmoid(
                projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
            highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
            timestep_output = highway_gate * timestep_output + \
                (1 - highway_gate) * highway_input_projection

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None and self.training:
            timestep_output = timestep_output * dropout_mask
        return timestep_output, memory

    def get_rel_dist(self, obj_preds, obj_feats, rel_inds, vr=None):
        obj_embed2 = self.obj_embed2(obj_preds)
        edge_ctx = torch.cat((obj_embed2, obj_feats), 1)

        edge_rep = self.post_lstm(edge_ctx)
        # Split into subject and object representations.
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)
        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = torch.tanh(prod_rep)

        rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    obj_preds[rel_inds[:, 1]],
                    obj_preds[rel_inds[:, 2]],
                ), 1))
        return rel_dists

    def get_freq_rel_dist(self, obj_preds, rel_inds):
        """
        Baseline: frequency-only relation model
        """
        rel_dists = self.freq_bias.index_with_labels(
            torch.stack((
                obj_preds[rel_inds[:, 1]],
                obj_preds[rel_inds[:, 2]],
            ), 1))
        return rel_dists

    def get_simple_rel_dist(self, obj_preds, rel_inds):
        obj_preds_np = obj_preds.cpu().numpy()
        rel_inds_np = rel_inds.cpu().numpy()
        rel_dists_list = []
        o1o2 = obj_preds_np[rel_inds_np][:, 1:]
        for o1, o2 in o1o2:
            rel_dists_list.append(self.prob_matrix[o1, o2])
        assert len(rel_dists_list) == len(rel_inds)
        return Variable(
            torch.from_numpy(np.array(rel_dists_list)).cuda(
                obj_preds.get_device()))  # no gradient flows through this lookup

    def forward(self,  # pylint: disable=arguments-differ
                # inputs: PackedSequence,
                sequence_tensor,
                rel_inds,
                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                labels=None,
                boxes_for_nms=None,
                vr=None):

        # get the relations for each object
        # numer = torch.arange(0, rel_inds.size(0)).long().cuda(rel_inds.get_device())
        # objs_to_outrels = sequence_tensor.data.new(sequence_tensor.size(0),
        #                                            rel_inds.size(0)).zero_()
        # objs_to_outrels.view(-1)[rel_inds[:, 1] * rel_inds.size(0) + numer] = 1
        # objs_to_outrels = Variable(objs_to_outrels)
        # objs_to_inrels = sequence_tensor.data.new(sequence_tensor.size(0), rel_inds.size(0)).zero_()
        # objs_to_inrels.view(-1)[rel_inds[:, 2] * rel_inds.size(0) + numer] = 1
        # # average the relations for each object, and add "non relation" to the ones with no relation communication
        # # test8 / test10 need comments
        # objs_to_inrels = objs_to_inrels / (objs_to_inrels.sum(1) + 1e-8)[:, None]
        # objs_to_inrels = Variable(objs_to_inrels)

        batch_size = sequence_tensor.size(0)

        # We're just doing an LSTM decoder here so ignore states, etc
        if initial_state is None:
            previous_memory = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
            previous_state = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
        else:
            assert len(initial_state) == 2
            previous_state = initial_state[0].squeeze(0)
            previous_memory = initial_state[1].squeeze(0)

        # 'start'
        previous_embed = self.obj_embed.weight[0, None].expand(batch_size, 100)
        # previous_comm_info = Variable(sequence_tensor.data.new()
        #                               .resize_(batch_size, 100).fill_(0))

        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            previous_memory)
        else:
            dropout_mask = None

        # Only accumulating label predictions here, discarding everything else
        out_dists_list = []
        out_commitments_list = []
        end_ind = 0

        for i in range(self.num_iter):
            # timestep_input = torch.cat((sequence_tensor, previous_embed, previous_comm_info), 1)
            timestep_input = torch.cat((sequence_tensor, previous_embed), 1)

            previous_state, previous_memory = self.lstm_equations(
                timestep_input, previous_state, previous_memory,
                dropout_mask=dropout_mask)

            pred_dist = self.out(previous_state)
            out_dists_list.append(pred_dist)

            # if self.training:
            #     labels_to_embed = labels.clone()
            #     # Whenever labels are 0 set input to be our max prediction
            #     nonzero_pred = pred_dist[:, 1:].max(1)[1] + 1
            #     is_bg = (labels_to_embed.data == 0).nonzero()
            #     if is_bg.dim() > 0:
            #         labels_to_embed[is_bg.squeeze(1)] = nonzero_pred[is_bg.squeeze(1)]
            #     out_commitments_list.append(labels_to_embed)
            #     previous_embed = self.obj_embed(labels_to_embed + 1)
            # else:
            #     out_dist_sample = F.softmax(pred_dist, dim=1)
            #     # if boxes_for_nms is not None:
            #     #     out_dist_sample[domains_allowed[i] == 0] = 0.0
            #     # Greedily take the max here amongst non-bgs
            #     best_ind = out_dist_sample[:, 1:].max(1)[1] + 1
            #     # if boxes_for_nms is not None and i < boxes_for_nms.size(0):
            #     #     best_int = int(best_ind.data[0])
            #     #     domains_allowed[i:, best_int] *= (1 - is_overlap[i, i:, best_int])
            #     out_commitments_list.append(best_ind)
            #     previous_embed = self.obj_embed(best_ind + 1)

            if self.training and (not self.sl_pretrain):
                # import pdb; pdb.set_trace()  # leftover debugging breakpoint
                out_dist_sample = F.softmax(pred_dist, dim=1)
                # sampling at training stage
                sample_ind = out_dist_sample[:, 1:].multinomial(1)[:, 0] + 1
                out_commitments_list.append(sample_ind)
                previous_embed = self.obj_embed(sample_ind + 1)
            else:
                out_dist_sample = F.softmax(pred_dist, dim=1)
                # best_ind = out_dist_sample[:, 1:].max(1)[1] + 1  # debug
                best_ind = out_dist_sample.max(1)[1]  # NOTE: argmax over all classes, including background
                out_commitments_list.append(best_ind)
                previous_embed = self.obj_embed(best_ind + 1)

            # calculate communicated information
            # rel_dists = self.get_rel_dist(best_ind, sequence_tensor, rel_inds, vr)
            # all_comm_info = rel_dists @ self.rel_embed.weight
            # obj_rel_weights = sequence_tensor @ torch.transpose(self.obj_rel_att.weight, 1, 0) @ torch.transpose(all_comm_info, 1, 0)
            # masked_objs_to_inrels = obj_rel_weights * objs_to_inrels
            # objs_to_inrels = masked_objs_to_inrels / (masked_objs_to_inrels.sum(1) + 1e-8)[:, None]
            # previous_comm_info = self.obj_in_lin(objs_to_inrels @ all_comm_info)
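        # Only the final iteration's distribution and labels are used below;
        # the full per-iteration lists are also returned for potential supervision.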
        out_dists = out_dists_list[-1]
        out_commitments = out_commitments_list[-1]

        # Do NMS here as a post-processing step
        """
        if boxes_for_nms is not None and not self.training:
            is_overlap = nms_overlaps(boxes_for_nms.data).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0),
                boxes_for_nms.size(1)).cpu().numpy() >= self.nms_thresh
            # is_overlap[np.arange(boxes_for_nms.size(0)), np.arange(boxes_for_nms.size(0))] = False

            out_dists_sampled = F.softmax(out_dists).data.cpu().numpy()
            out_dists_sampled[:, 0] = -1.0  # changed from 0.0 to fix the bug where the bg score is almost 1.

            out_commitments = out_commitments.data.new(len(out_commitments)).fill_(0)

            for i in range(out_commitments.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(),
                                                    out_dists_sampled.shape)
                out_commitments[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind, :, cls_ind], cls_ind] = -1.0  # 0.0
                out_dists_sampled[box_ind] = -1.0  # This way we won't re-sample

            out_commitments = Variable(out_commitments)
        """

        # rel_dists = self.get_rel_dist(out_commitments, sequence_tensor, rel_inds, vr)
        # simple model
        # import pdb; pdb.set_trace()
        # rel_dists = self.get_freq_rel_dist(out_commitments, rel_inds)
        rel_dists = self.get_simple_rel_dist(out_commitments.data, rel_inds)

        return out_dists_list, out_commitments_list, None, \
            out_dists, out_commitments, rel_dists
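# The following is a minimal, self-contained sketch (not part of the model) of
# the "simple relation model" used by prob_matrix / get_simple_rel_dist above:
# an empirical P(rel | subject class, object class) prior built from counts.
# The toy sizes and counts are hypothetical, for illustration only.
if __name__ == '__main__':
    import numpy as np

    num_obj_classes, num_rel_classes = 3, 4  # toy sizes; relation index 0 is background
    # fg_counts[s, o, r]: how often relation r was annotated between classes s and o.
    fg_counts = np.zeros((num_obj_classes, num_obj_classes, num_rel_classes), dtype=np.float32)
    fg_counts[1, 2, 1] = 30.0   # e.g. (person, bike, "riding")
    fg_counts[1, 2, 2] = 10.0   # e.g. (person, bike, "near")
    bg_counts = np.full((num_obj_classes, num_obj_classes), 5.0, dtype=np.float32)

    prob = fg_counts.copy()
    prob[:, :, 0] = bg_counts          # background counts go into relation index 0
    prob[:, :, 0] += 1                 # smoothing, as in the constructor above
    prob /= prob.sum(2)[:, :, None]    # normalize over relations -> P(rel | s, o)
    prob[:, :, 0] = 0                  # zero out the background column

    # Lookup for a single candidate pair, mirroring get_simple_rel_dist():
    subj_cls, obj_cls = 1, 2
    print(prob[subj_cls, obj_cls])     # distribution over relations for this pair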