import torch
import torch.nn as nn
import torch.nn.functional as F

# Project-local helpers (VectorPool, GraphTripleConv, GraphTripleConvNet,
# build_mlp, mask_net, AppearanceEncoder, define_G, masks_to_layout,
# get_cropped_objs, crop_bbox_batch, jitter_bbox) are assumed to be imported
# from this repository's own modules.


class Model(nn.Module):
    def __init__(self, vocab, image_size=(64, 64), embedding_dim=128,
                 gconv_dim=128, gconv_hidden_dim=512, gconv_pooling='avg',
                 gconv_num_layers=5, mask_size=32, mlp_normalization='none',
                 appearance_normalization='', activation='',
                 n_downsample_global=4, box_dim=128, use_attributes=False,
                 box_noise_dim=64, mask_noise_dim=64, pool_size=100,
                 rep_size=32):
        super(Model, self).__init__()
        self.vocab = vocab
        self.image_size = image_size
        self.use_attributes = use_attributes
        self.box_noise_dim = box_noise_dim
        self.mask_noise_dim = mask_noise_dim
        self.object_size = 64
        self.fake_pool = VectorPool(pool_size)

        self.num_objs = len(vocab['object_idx_to_name'])
        self.num_preds = len(vocab['pred_idx_to_name'])
        self.obj_embeddings = nn.Embedding(self.num_objs, embedding_dim)
        self.pred_embeddings = nn.Embedding(self.num_preds, embedding_dim)

        attributes_dim = vocab['num_attributes'] if use_attributes else 0

        # Graph convolution over the scene graph.
        if gconv_num_layers == 0:
            self.gconv = nn.Linear(embedding_dim, gconv_dim)
        elif gconv_num_layers > 0:
            gconv_kwargs = {
                'input_dim': embedding_dim,
                'attributes_dim': attributes_dim,
                'output_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv = GraphTripleConv(**gconv_kwargs)

        self.gconv_net = None
        if gconv_num_layers > 1:
            gconv_kwargs = {
                'input_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'num_layers': gconv_num_layers - 1,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

        # Box regression network: object vector -> 4 box coordinates.
        box_net_dim = 4
        self.box_dim = box_dim
        box_net_layers = [self.box_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

        # Mask network: object vector plus noise -> soft object mask.
        self.g_mask_dim = gconv_dim + mask_noise_dim
        self.mask_net = mask_net(self.g_mask_dim, mask_size)

        # Appearance representation network.
        self.repr_input = self.g_mask_dim
        rep_hidden_size = 64
        repr_layers = [self.repr_input, rep_hidden_size, rep_size]
        self.repr_net = build_mlp(repr_layers, batch_norm=mlp_normalization)

        appearance_encoder_kwargs = {
            'vocab': vocab,
            'arch': 'C4-64-2,C4-128-2,C4-256-2',
            'normalization': appearance_normalization,
            'activation': activation,
            'padding': 'valid',
            'vecs_size': self.g_mask_dim,
        }
        self.image_encoder = AppearanceEncoder(**appearance_encoder_kwargs)

        # Layout-to-image generator.
        netG_input_nc = self.num_objs + rep_size
        output_nc = 3
        ngf = 64
        n_blocks_global = 9
        norm = 'instance'
        self.layout_to_image = define_G(netG_input_nc, output_nc, ngf,
                                        n_downsample_global, n_blocks_global,
                                        norm)

    def forward(self, gt_imgs, objs, triples, obj_to_img, boxes_gt=None,
                masks_gt=None, attributes=None, gt_train=False,
                test_mode=False, use_gt_box=False, features=None,
                drop_box_idx=None, drop_feat_idx=None, src_image=None):
        O, T = objs.size(0), triples.size(0)

        obj_vecs, pred_vecs = self.scene_graph_to_vectors(
            objs, triples, attributes)

        box_vecs, mask_vecs, scene_layout_vecs, wrong_layout_vecs = \
            self.create_components_vecs(
                gt_imgs, boxes_gt, obj_to_img, objs, obj_vecs, features,
                drop_box_idx=drop_box_idx, drop_feat_idx=drop_feat_idx,
                src_image=src_image)

        # Generate boxes
        boxes_pred = self.box_net(box_vecs)
        vg = True

        # Generate masks with the mask prediction network
        masks_pred = None
        if self.mask_net is not None:
            mask_scores = self.mask_net(mask_vecs.view(O, -1, 1, 1))
            masks_pred = mask_scores.squeeze(1).sigmoid()

        H, W = self.image_size
        if vg:
            layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
            masks_gt = masks_pred if masks_gt is None else masks_gt

        if test_mode:
            boxes = boxes_gt if use_gt_box else boxes_pred
            masks = masks_gt if masks_gt is not None else masks_pred
            gt_layout = None
            pred_layout = masks_to_layout(scene_layout_vecs, boxes, masks,
                                          obj_to_img, H, W, test_mode=True)
            wrong_layout = None
            imgs_pred = self.layout_to_image(pred_layout)
        else:
            gt_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_gt,
                                        obj_to_img, H, W, test_mode=False)
            pred_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_pred,
                                          obj_to_img, H, W, test_mode=False)
            wrong_layout = masks_to_layout(wrong_layout_vecs, boxes_gt, masks_gt,
                                           obj_to_img, H, W, test_mode=False)
            imgs_pred = self.layout_to_image(gt_layout)

        return imgs_pred, boxes_pred, masks_pred, gt_layout, pred_layout, wrong_layout

    def scene_graph_to_vectors(self, objs, triples, attributes):
        s, p, o = triples.chunk(3, dim=1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]
        edges = torch.stack([s, o], dim=1)

        obj_vecs = self.obj_embeddings(objs)
        pred_vecs = self.pred_embeddings(p)
        if self.use_attributes:
            obj_vecs = torch.cat([obj_vecs, attributes], dim=1)

        if isinstance(self.gconv, nn.Linear):
            obj_vecs = self.gconv(obj_vecs)
        else:
            obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        if self.gconv_net is not None:
            obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        return obj_vecs, pred_vecs

    def create_components_vecs(self, imgs, boxes, obj_to_img, objs, obj_vecs,
                               features, drop_box_idx=None, drop_feat_idx=None,
                               jitter=(None, None, None),
                               jitter_range=((-0.05, 0.05), (-0.05, 0.05)),
                               src_image=None):
        O = objs.size(0)
        box_vecs = obj_vecs
        mask_vecs = obj_vecs

        # Append shared per-object noise to the mask vectors.
        layout_noise = torch.randn((1, self.mask_noise_dim),
                                   dtype=mask_vecs.dtype,
                                   device=mask_vecs.device) \
            .repeat((O, 1)) \
            .view(O, self.mask_noise_dim)
        mask_vecs = torch.cat([mask_vecs, layout_noise], dim=1)

        jitterFeat = False

        # Create the appearance encoding.
        if features is None:
            if jitterFeat:
                # Note: this branch is disabled (jitterFeat is False) and
                # relies on a dropout probability self.p being set elsewhere.
                if obj_to_img is None:
                    obj_to_img = torch.zeros(O, dtype=objs.dtype,
                                             device=objs.device)
                    imgbox_idx = -1
                else:
                    imgbox_idx = torch.zeros(src_image.size(0), dtype=torch.int64)
                    for i in range(src_image.size(0)):
                        imgbox_idx[i] = (obj_to_img == i).nonzero()[-1]

                add_jitter_bbox, add_jitter_layout, add_jitter_feats = jitter
                jitter_range_bbox, jitter_range_layout = jitter_range

                # Bounding boxes ---------------------------------------------
                box_ones = torch.ones([O, 1], dtype=boxes.dtype,
                                      device=boxes.device)
                if drop_box_idx is not None:
                    box_keep = drop_box_idx
                else:
                    # Drop random box(es)
                    box_keep = F.dropout(box_ones, self.p, True, False) * (1 - self.p)

                # The image object cannot be dropped
                box_keep[imgbox_idx, :] = 1

                if add_jitter_bbox is not None:
                    boxes_gt = jitter_bbox(boxes,
                                           p=add_jitter_bbox,
                                           noise_range=jitter_range_bbox,
                                           eval_mode=True)  # uses default settings

                boxes_prior = boxes * box_keep

                # Object features --------------------------------------------
                if drop_feat_idx is not None:
                    feats_keep = drop_feat_idx
                else:
                    feats_keep = F.dropout(box_ones, self.p, True, False) * (1 - self.p)

                # Image object feats: should they be dropped or should they not?
                feats_keep[imgbox_idx, :] = 1
                obj_crop, src_image, generated = get_cropped_objs(
                    src_image, boxes, obj_to_img, box_keep, feats_keep, True)

            crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
            obj_repr = self.repr_net(self.image_encoder(crops))
        else:
            # Only at inference time: use the provided features where given.
            obj_repr = self.repr_net(mask_vecs)
            for ind, feature in enumerate(features):
                if feature is not None:
                    obj_repr[ind, :] = feature

        # Create a one-hot vector for the label map.
        one_hot_size = (O, self.num_objs)
        one_hot_obj = torch.zeros(one_hot_size, dtype=obj_repr.dtype,
                                  device=obj_repr.device)
        one_hot_obj = one_hot_obj.scatter_(1, objs.view(-1, 1).long(), 1.0)

        layout_vecs = torch.cat([one_hot_obj, obj_repr], dim=1)

        # Mismatched object representations queried from the pool, used to
        # build the "wrong" layout.
        wrong_objs_rep = self.fake_pool.query(objs, obj_repr)
        wrong_layout_vecs = torch.cat([one_hot_obj, wrong_objs_rep], dim=1)

        return box_vecs, mask_vecs, layout_vecs, wrong_layout_vecs

    def encode_scene_graphs(self, scene_graphs, rand=False):
        """
        Encode one or more scene graphs using this model's vocabulary.

        Inputs to this method are scene graphs represented as dictionaries
        like the following:

        {
          "objects": ["cat", "dog", "sky"],
          "relationships": [
            [0, "next to", 1],
            [0, "beneath", 2],
            [2, "above", 1],
          ]
        }

        This scene graph has three relationships: cat next to dog, cat beneath
        sky, and sky above dog.

        Inputs:
        - scene_graphs: A dictionary giving a single scene graph, or a list of
          dictionaries giving a sequence of scene graphs.

        Returns a tuple of LongTensors (objs, triples, obj_to_img) that have
        the same semantics as self.forward. The returned LongTensors will be
        on the same device as the model parameters.
        """
        if isinstance(scene_graphs, dict):
            # We just got a single scene graph, so promote it to a list
            scene_graphs = [scene_graphs]

        device = next(self.parameters()).device
        objs, triples, obj_to_img = [], [], []
        all_attributes = []
        all_features = []
        obj_offset = 0
        for i, sg in enumerate(scene_graphs):
            attributes = torch.zeros([len(sg['objects']) + 1, 25 + 10],
                                     dtype=torch.float, device=device)
            # Insert dummy __image__ object and __in_image__ relationships
            sg['objects'].append('__image__')
            sg['features'].append(sg['image_id'])
            image_idx = len(sg['objects']) - 1
            for j in range(image_idx):
                sg['relationships'].append([j, '__in_image__', image_idx])

            for obj in sg['objects']:
                obj_idx = self.vocab['object_name_to_idx'].get(obj, None)
                if obj_idx is None:
                    raise ValueError('Object "%s" not in vocab' % obj)
                objs.append(obj_idx)
                obj_to_img.append(i)

            # self.features / self.features_one are not set in __init__; they
            # are expected to be attached to the model (per-object feature
            # arrays) before this method is called.
            if self.features is not None:
                for obj_name, feat_num in zip(objs, sg['features']):
                    if feat_num == -1:
                        feat = self.features_one[obj_name][0]
                    else:
                        feat = self.features[obj_name][min(feat_num, 99), :]
                    feat = torch.from_numpy(feat).type(torch.float32).to(device)
                    all_features.append(feat)

            for s, p, o in sg['relationships']:
                pred_idx = self.vocab['pred_name_to_idx'].get(p, None)
                if pred_idx is None:
                    raise ValueError('Relationship "%s" not in vocab' % p)
                triples.append([s + obj_offset, pred_idx, o + obj_offset])

            for i, size_attr in enumerate(sg['attributes']['size']):
                attributes[i, size_attr] = 1
            # in image size
            attributes[-1, 9] = 1
            for i, location_attr in enumerate(sg['attributes']['location']):
                attributes[i, location_attr + 10] = 1
            # in image location
            attributes[-1, 12 + 10] = 1

            obj_offset += len(sg['objects'])
            all_attributes.append(attributes)

        objs = torch.tensor(objs, dtype=torch.int64, device=device)
        triples = torch.tensor(triples, dtype=torch.int64, device=device)
        obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64, device=device)
        attributes = torch.cat(all_attributes)
        features = all_features

        return objs, triples, obj_to_img, attributes, features

    def forward_json(self, scene_graphs):
        """Convenience method that combines encode_scene_graphs and forward."""
        objs, triples, obj_to_img, attributes, features = self.encode_scene_graphs(
            scene_graphs)
        return self.forward(None, objs, triples, obj_to_img,
                            attributes=attributes, gt_train=False,
                            test_mode=True, use_gt_box=False,
                            features=features)
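
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original model code). It
# shows how a scene-graph dictionary in the format documented in
# encode_scene_graphs() could be passed to forward_json(). The object names,
# relationship names, and attribute indices below are placeholders and must
# exist in the model's vocab; the vocab itself, any pretrained weights, and
# the optional model.features / model.features_one feature banks come from
# the surrounding data/training pipeline and are assumed here.
def _example_forward_json(model):
    model.features = None  # or attach precomputed per-object feature arrays
    scene_graph = {
        'image_id': 0,
        'objects': ['cat', 'dog', 'sky'],
        'relationships': [
            [0, 'next to', 1],
            [0, 'beneath', 2],
            [2, 'above', 1],
        ],
        # One feature index per object; -1 selects the default feature when
        # model.features is set.
        'features': [-1, -1, -1],
        # Placeholder indices; they must fit the 35-dimensional (25 + 10)
        # attribute layout used in encode_scene_graphs().
        'attributes': {'size': [3, 3, 5], 'location': [4, 4, 1]},
    }
    imgs_pred, boxes_pred, masks_pred, _, pred_layout, _ = \
        model.forward_json(scene_graph)
    return imgs_pred, boxes_pred, masks_pred, pred_layout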