def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, use_resnet=False, thresh=0.01, use_proposals=False, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True, sl_pretrain=False, eval_rel_objs=False, num_iter=-1, reduce_input=False, post_nms_thresh=0.5): super(RelModelAlign, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode # self.pooling_size = 7 self.pooling_size = conf.pooling_size self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.num_iter = num_iter self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.sl_pretrain = sl_pretrain self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = DynamicFilterContext(self.classes, self.rel_classes, mode=self.mode, use_vision=self.use_vision, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, pooling_dim=self.pooling_dim, pooling_size=self.pooling_size, dropout_rate=rec_dropout, use_bias=self.use_bias, use_tanh=self.use_tanh, limit_vision=self.limit_vision, sl_pretrain = self.sl_pretrain, num_iter=self.num_iter, use_resnet=use_resnet, reduce_input=reduce_input, post_nms_thresh=post_nms_thresh)
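# Hypothetical helper (not in the repo) that just spells out the nested conditional used
# for the ObjectDetector mode above: detection boxes are only predicted in 'sgdet';
# otherwise ground-truth boxes are used.
def detector_mode(mode, use_proposals):
    if mode != 'sgdet':
        return 'gtbox'
    return 'proposals' if use_proposals else 'refinerels'

assert detector_mode('predcls', False) == 'gtbox'
assert detector_mode('sgdet', False) == 'refinerels'
assert detector_mode('sgdet', True) == 'proposals'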
    num_gpus=detector_num_gpus)
else:
    train, val, _ = VG.splits(num_val_im=conf.val_size, filter_non_overlap=False,
                              filter_empty_rels=False, use_proposals=conf.use_proposals)
    train_loader, val_loader = VGDataLoader.splits(
        train, val, batch_size=conf.batch_size,
        num_workers=conf.num_workers, num_gpus=detector_num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes, num_gpus=detector_num_gpus,
    mode='refinerels' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
# print(detector)
# os._exit(0)
detector.cuda()

# Note: if you're doing the stanford setup, you'll need to change this to freeze the lower layers
if conf.use_proposals:
    for n, param in detector.named_parameters():
        if n.startswith('features'):
            param.requires_grad = False

start_epoch = -1
if conf.ckpt is not None:
    ckpt = torch.load(conf.ckpt)
from config import ModelConfig, FG_FRACTION, RPN_FG_FRACTION, IM_SCALE, BOX_SCALE
from torch.nn import functional as F
from lib.fpn.box_utils import bbox_loss
import torch.backends.cudnn as cudnn
from lib.pytorch_misc import optimistic_restore, clip_grad_norm
from torch.optim.lr_scheduler import ReduceLROnPlateau

os.environ['CUDA_VISIBLE_DEVICES'] = '3'
torch.cuda.set_device(0)
cudnn.benchmark = True
conf = ModelConfig()

train, val, _ = VG.splits(num_val_im=conf.val_size, filter_non_overlap=False,
                          filter_empty_rels=False, use_proposals=False)
train_loader, val_loader = VGDataLoader.splits(train, val, batch_size=1,
                                               num_workers=1, num_gpus=1)

detector = ObjectDetector(classes=train.ind_to_classes, num_gpus=1,
                          mode='refinerels', use_resnet=False)
detector.cuda()
ckpt = torch.load(conf.ckpt)
optimistic_restore(detector, ckpt['state_dict'])
detector.eval()

for batch in train_loader:
    results = detector[batch]
def __init__(self, classes, rel_classes, graph_path, emb_path, mode='sgdet', num_gpus=1, require_overlap_det=True, pooling_dim=4096, use_resnet=False, thresh=0.01, use_proposals=False, ggnn_rel_time_step_num=3, ggnn_rel_hidden_dim=512, ggnn_rel_output_dim=512, use_knowledge=True, use_embedding=True, refine_obj_cls=False, rel_counts_path=None, class_volume=1.0, top_k_to_keep=5, normalize_messages=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param require_overlap_det: Whether two objects must intersect """ super(KERN, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.obj_dim = 2048 if use_resnet else 4096 self.rel_dim = self.obj_dim self.pooling_dim = pooling_dim self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64 ) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier self.ggnn_rel_reason = GGNNRelReason(mode=self.mode, num_obj_cls=len(self.classes), num_rel_cls=len(rel_classes), obj_dim=self.obj_dim, rel_dim=self.rel_dim, time_step_num=ggnn_rel_time_step_num, hidden_dim=ggnn_rel_hidden_dim, output_dim=ggnn_rel_output_dim, emb_path=emb_path, graph_path=graph_path, refine_obj_cls=refine_obj_cls, use_knowledge=use_knowledge, use_embedding=use_embedding, top_k_to_keep=top_k_to_keep, normalize_messages=normalize_messages ) if rel_counts_path is not None: with open(rel_counts_path, 'rb') as fin: rel_counts = pickle.load(fin) beta = (class_volume - 1.0) / class_volume self.rel_class_weights = (1.0 - beta) / (1 - (beta ** rel_counts)) self.rel_class_weights *= float(self.num_rels) / np.sum(self.rel_class_weights) else: self.rel_class_weights = np.ones((self.num_rels,)) self.rel_class_weights = Variable(torch.from_numpy(self.rel_class_weights).float().cuda(), requires_grad=False)
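import numpy as np

# Toy illustration of the relationship class weighting above (not the real VG counts):
# with beta = (class_volume - 1) / class_volume, the effective-number weighting
# (1 - beta) / (1 - beta ** count) gives rare predicates larger weights, and the final
# rescaling makes the weights average to 1 across the predicate classes.
rel_counts = np.array([10000.0, 500.0, 20.0, 1.0])        # hypothetical per-predicate counts
class_volume = 100.0
beta = (class_volume - 1.0) / class_volume
weights = (1.0 - beta) / (1.0 - beta ** rel_counts)
weights *= len(rel_counts) / weights.sum()
print(weights)                                            # the rarest class gets the largest weight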
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, gnn=True, reachability=False, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.reachability = reachability self.gnn = gnn self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.global_logist = nn.Linear(4096, 151, bias=True) # CosineLinear(4096,150)# self.global_logist.weight = torch.nn.init.xavier_normal( self.global_logist.weight, gain=1.0) self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim + 256) self.meta_classify = MetaEmbedding_Classifier( feat_dim=self.pooling_dim + 256, num_classes=self.num_rels) # self.global_rel_logist = nn.Linear(4096, 50 , bias=True) # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0) # self.global_logist = CosineLinear(4096,150) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.edge_coordinate_embedding = nn.Sequential(*[ nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(5, 256), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) # Initialize to sqrt(1/2n) so that 
# the outputs all have mean 0 and variance 1.
# (Half contribution comes from LSTM, half from embedding.
# In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
self.post_lstm.bias.data.zero_()

if nl_edge == 0:
    self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
    self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

self.rel_compress = nn.Linear(4096 + 256, 51, bias=True)
self.rel_compress.weight = torch.nn.init.xavier_normal(
    self.rel_compress.weight, gain=1.0)
self.node_transform = nn.Linear(4096, 256, bias=True)
self.edge_transform = nn.Linear(4096, 256, bias=True)
# self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels)
# self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)

if self.use_bias:
    self.freq_bias = FrequencyBias()

if self.gnn:
    self.graph_network_node = GraphNetwork(4096)
    self.graph_network_edge = GraphNetwork()
    if self.training:
        self.graph_network_node.train()
        self.graph_network_edge.train()
    else:
        self.graph_network_node.eval()
        self.graph_network_edge.eval()

self.edge_sim_network = nn.Linear(4096, 1, bias=True)
self.metric_net = MetricLearning()
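import math
import torch
import torch.nn as nn

# Rough numerical check of the initialization argument in the comment above (a sketch,
# not the training code): if the pre-LSTM features have stdev ~0.1, a weight stdev of
# 10 * sqrt(1 / hidden_dim) makes the projected outputs come out with roughly unit variance.
hidden_dim, pooling_dim = 256, 2048
post_lstm = nn.Linear(hidden_dim, pooling_dim * 2)
post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / hidden_dim))
post_lstm.bias.data.zero_()

x = 0.1 * torch.randn(1000, hidden_dim)                   # features with stdev ~0.1
print(post_lstm(x).std())                                  # close to 1.0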
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=4096, nl_obj=1, nl_edge=2, use_resnet=True, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True, spatial_dim=128, graph_constrain=True, mp_iter_num=1): """ Args: mp_iter_num: integer, number of message passing iteration """ super(FckModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = obj_dim self.pooling_dim = 2048 if use_resnet else 4096 self.spatial_dim = spatial_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.graph_cons = graph_constrain self.mp_iter_num = mp_iter_num classes_word_vec = obj_edge_vectors(self.classes, wv_dim=embed_dim) self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim) self.classes_word_embedding.weight.data = classes_word_vec.clone() self.classes_word_embedding.weight.requires_grad = False # the last one is dirty bit self.rel_mem = nn.Embedding(self.num_rels, self.obj_dim + 1) self.rel_mem.weight.data[:, -1] = 0 if mode == 'sgdet': if use_proposals: obj_detector_mode = 'proposals' else: obj_detector_mode = 'refinerels' else: obj_detector_mode = 'gtbox' self.detector = ObjectDetector( classes=classes, mode=obj_detector_mode, use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512, use_feats=False) self.spatial_fc = nn.Sequential(*[ nn.Linear(4, spatial_dim), nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.), nn.ReLU(inplace=True) ]) self.word_fc = nn.Sequential(*[ nn.Linear(2 * embed_dim, hidden_dim), nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.), nn.ReLU(inplace=True) ]) # union box feats feats_dim = obj_dim + spatial_dim + hidden_dim self.relpn_fc = nn.Linear(feats_dim, 2) self.relcnn_fc1 = nn.Sequential( *[nn.Linear(feats_dim, feats_dim), nn.ReLU(inplace=True)]) self.box_mp_fc = nn.Sequential(*[ nn.Linear(obj_dim, obj_dim), ]) self.sub_rel_mp_fc = nn.Sequential(*[nn.Linear(feats_dim, obj_dim)]) self.obj_rel_mp_fc = nn.Sequential(*[ nn.Linear(feats_dim, obj_dim), ]) self.mp_atten_fc = nn.Sequential(*[ nn.Linear(feats_dim + obj_dim, obj_dim), nn.ReLU(inplace=True), nn.Linear(obj_dim, 1) ]) self.cls_fc = nn.Linear(obj_dim, self.num_classes) self.relcnn_fc2 = nn.Linear( feats_dim, self.num_rels if self.graph_cons else 2 * self.num_rels) if use_resnet: #deprecate self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ load_vgg( use_dropout=False, use_relu=False, use_linear=self.obj_dim == 4096, pretrained=False, ).classifier, nn.Linear(self.pooling_dim, self.obj_dim) ] self.roi_fmap = nn.Sequential(*roi_fmap)
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision=limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' # print('REL MODEL CONSTRUCTOR: 1') self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) # print('REL MODEL CONSTRUCTOR: 2') self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # print('REL MODEL CONSTRUCTOR: 3') if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # print('REL MODEL CONSTRUCTOR: 4') ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim)) self.post_lstm.bias.data.zero_() # print('REL MODEL CONSTRUCTOR: 5') if nl_edge == 0: self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2) self.post_emb.weight.data.normal_(0, math.sqrt(1.0)) self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) if self.use_bias: self.freq_bias = FrequencyBias()
if conf.test:
    print("test data!")
    val = test

train_loader, val_loader = VGDataLoader.splits(train, val, mode='rel',
                                               batch_size=conf.batch_size,
                                               num_workers=conf.num_workers,
                                               num_gpus=conf.num_gpus)

fg_matrix, bg_matrix = get_counts(train_data=train, must_overlap=MUST_OVERLAP)

detector = ObjectDetector(
    classes=train.ind_to_classes, num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet, nms_filter_duplicates=True, thresh=0.01)
detector.eval()
detector.cuda()

classifier = ObjectDetector(
    classes=train.ind_to_classes, num_gpus=conf.num_gpus, mode='gtbox',
    use_resnet=conf.use_resnet, nms_filter_duplicates=True, thresh=0.01)
classifier.eval()
classifier.cuda()
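import numpy as np

# Hedged illustration of what the foreground count matrix returned by get_counts holds:
# fg_matrix[s, o, p] counts how often predicate p links subject class s to object class o
# in the training split (must_overlap restricts the counts to overlapping box pairs).
num_classes, num_rels = 151, 51
fg_matrix = np.zeros((num_classes, num_classes, num_rels), dtype=np.int64)
for s, o, p in [(1, 2, 5), (1, 2, 5), (3, 4, 7)]:         # toy (subject, object, predicate) triples
    fg_matrix[s, o, p] += 1
print(fg_matrix[1, 2, 5])                                  # 2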
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=False, embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=4096, nl_obj=1, nl_edge=2, use_resnet=True, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True, spatial_dim=128, mp_iter_num=1, trim_graph=True): """ Args: mp_iter_num: integer, number of message passing iteration trim_graph: boolean, trim graph in rel pn """ super(FckModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = obj_dim self.pooling_dim = 2048 if use_resnet else 4096 self.spatial_dim = spatial_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.mp_iter_num = mp_iter_num self.trim_graph = trim_graph classes_word_vec = obj_edge_vectors(self.classes, wv_dim=embed_dim) self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim) self.classes_word_embedding.weight.data = classes_word_vec.clone() self.classes_word_embedding.weight.requires_grad = False #fg_matrix, bg_matrix = get_counts() #rel_obj_distribution = fg_matrix / (fg_matrix.sum(2)[:, :, None] + 1e-5) #rel_obj_distribution = torch.FloatTensor(rel_obj_distribution) #rel_obj_distribution = rel_obj_distribution.view(-1, self.num_rels) # #self.rel_obj_distribution = nn.Embedding(rel_obj_distribution.size(0), self.num_rels) ## (#obj_class * #obj_class, #rel_class) #self.rel_obj_distribution.weight.data = rel_obj_distribution if mode == 'sgdet': if use_proposals: obj_detector_mode = 'proposals' else: obj_detector_mode = 'refinerels' else: obj_detector_mode = 'gtbox' self.detector = ObjectDetector( classes=classes, mode=obj_detector_mode, use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512, use_feats=False) self.spatial_fc = nn.Sequential(*[ nn.Linear(4, spatial_dim), nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.), nn.ReLU(inplace=True) ]) self.word_fc = nn.Sequential(*[ nn.Linear(2 * embed_dim, hidden_dim), nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.), nn.ReLU(inplace=True) ]) # union box feats feats_dim = obj_dim + spatial_dim + hidden_dim self.relpn_fc = nn.Linear(feats_dim, 2) self.relcnn_fc1 = nn.Sequential( *[nn.Linear(feats_dim, feats_dim), nn.ReLU(inplace=True)]) # v2 model--------- self.box_mp_fc = nn.Sequential(*[ nn.Linear(obj_dim, obj_dim), ]) self.sub_rel_mp_fc = nn.Sequential(*[nn.Linear(feats_dim, obj_dim)]) self.obj_rel_mp_fc = nn.Sequential(*[ nn.Linear(feats_dim, obj_dim), ]) self.mp_atten_fc = nn.Sequential(*[ nn.Linear(feats_dim + obj_dim, obj_dim), nn.ReLU(inplace=True), nn.Linear(obj_dim, 1) ]) # v2 model---------- self.cls_fc = nn.Linear(obj_dim, self.num_classes) self.relcnn_fc2 = nn.Linear(feats_dim, self.num_rels) # v3 model ----------- self.mem_module = MemoryRNN(classes=classes, rel_classes=rel_classes, inputs_dim=feats_dim, hidden_dim=hidden_dim, recurrent_dropout_probability=.0) # v3 model ----------- if use_resnet: # deprecate self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: 
roi_fmap = [
    load_vgg(
        use_dropout=False,
        use_relu=False,
        use_linear=self.obj_dim == 4096,
        pretrained=False,
    ).classifier,
    nn.Linear(self.pooling_dim, self.obj_dim)
]
self.roi_fmap = nn.Sequential(*roi_fmap)
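import torch
import torch.nn as nn
import torch.nn.functional as F

# Hedged sketch of how the 2-way relation-proposal head (relpn_fc) defined in this
# constructor could be used to trim candidate object pairs before the relation classifier;
# the tensors and the top-k rule here are illustrative, not the repo's exact forward pass.
obj_dim, spatial_dim, hidden_dim = 2048, 128, 256
feats_dim = obj_dim + spatial_dim + hidden_dim             # union visual + spatial + word features
relpn_fc = nn.Linear(feats_dim, 2)

pair_feats = torch.randn(12, feats_dim)                    # 12 candidate (subject, object) pairs
rel_scores = F.softmax(relpn_fc(pair_feats), dim=1)[:, 1]  # P(pair participates in a relation)
keep = rel_scores.topk(5).indices                          # graph trimming: keep the 5 best pairs
print(keep)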
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ Args: classes: list, list of 151 object class names(including background) rel_classes: list, list of 51 predicate names( including background(norelationship)) mode: string, 'sgdet', 'predcls' or 'sgcls' num_gpus: integer, number of GPUs to use use_vision: boolean, whether to use vision in the final product require_overlap_det: boolean, whether two object must intersect embed_dim: integer, number of dimension for all embeddings hidden_dim: integer, hidden size of LSTM pooling_dim: integer, outputsize of vgg fc layer nl_obj: integer, number of object context layer, 2 in paper nl_edge: integer, number of edge context layer, 4 in paper use_resnet: integer, use resnet for backbone order: string, value must be in ('size', 'confidence', 'random', 'leftright'), order of RoIs thresh: float, threshold for scores of boxes if score of box smaller than thresh, then it will be abandoned use_proposals: boolean, whether to use proposals pass_in_obj_feats_to_decoder: boolean, whether to pass object features to decoder RNN pass_in_obj_feats_to_edge: boolean, whether to pass object features to edge context RNN rec_dropout: float, dropout rate in RNN use_bias: boolean, use_tanh: boolean, limit_vision: boolean, """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. 
# In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
self.post_lstm.bias.data.zero_()

if nl_edge == 0:
    self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
    self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
self.rel_compress.weight = torch.nn.init.xavier_normal(
    self.rel_compress.weight, gain=1.0)

if self.use_bias:
    self.freq_bias = FrequencyBias()
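import torch
import torch.nn as nn

# Self-contained sketch of the idea behind the FrequencyBias used above: a lookup table of
# predicate logits indexed by the (subject class, object class) pair, added to the visual
# relation scores from rel_compress. The table here is random; the real one is built from
# dataset statistics.
num_classes, num_rels, pooling_dim = 151, 51, 4096
rel_compress = nn.Linear(pooling_dim, num_rels, bias=True)
freq_table = nn.Embedding(num_classes * num_classes, num_rels)

union_feats = torch.randn(8, pooling_dim)                   # 8 candidate pairs
subj = torch.randint(0, num_classes, (8,))
obj = torch.randint(0, num_classes, (8,))
rel_dists = rel_compress(union_feats) + freq_table(subj * num_classes + obj)
print(rel_dists.shape)                                      # torch.Size([8, 51])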
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, model_path='', reachability=False, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, init_center=False, limit_vision=True): super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.init_center = init_center self.pooling_size = 7 self.model_path = model_path self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.centroids = None self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.global_embedding = EmbeddingImagenet(4096) self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2) self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim) self.meta_classify = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_rels) self.disc_center_g = DiscCentroidsLoss(self.num_classes, self.pooling_dim) self.meta_classify_g = MetaEmbedding_Classifier( feat_dim=self.pooling_dim, num_classes=self.num_classes) self.global_sub_additive = nn.Linear(4096, 1, bias=True) self.global_obj_additive = nn.Linear(4096, 1, bias=True) # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1. # (Half contribution comes from LSTM, half from embedding. # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10. 
self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
self.post_lstm.bias.data.zero_()

self.global_logist = nn.Linear(self.pooling_dim, self.num_classes, bias=True)  # CosineLinear(4096,150)
self.global_logist.weight = torch.nn.init.xavier_normal(
    self.global_logist.weight, gain=1.0)

self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
self.rel_compress.weight = torch.nn.init.xavier_normal(
    self.rel_compress.weight, gain=1.0)

if self.use_bias:
    self.freq_bias = FrequencyBias()

self.class_num = torch.zeros(len(self.classes))
self.centroids = torch.zeros(len(self.classes), self.pooling_dim).cuda()
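import torch

# Hedged sketch of a running per-class centroid update that is consistent with the
# class_num / centroids buffers above; the exact update rule is not shown in this snippet.
num_classes, feat_dim = 151, 4096
class_num = torch.zeros(num_classes)
centroids = torch.zeros(num_classes, feat_dim)

def update_centroids(feats, labels):
    # feats: (N, feat_dim) pooled object features; labels: (N,) class indices
    for f, c in zip(feats, labels.tolist()):
        class_num[c] += 1
        centroids[c] += (f - centroids[c]) / class_num[c]   # incremental mean

update_centroids(torch.randn(4, feat_dim), torch.tensor([3, 3, 7, 0]))
print(class_num[3], centroids[3].mean())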
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, require_overlap_det=True, embed_dim=200, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) """ super(NODIS, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.obj_dim = 2048 if use_resnet else 4096 self.order = 'random' self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.context = O_NODE(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, obj_dim=self.obj_dim, order=order) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: self.roi_fmap_obj = load_vgg(pretrained=False).classifier self.roi_avg_pool = nn.AvgPool2d(kernel_size=7, stride=0) ################################### embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim) self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed.weight.data = embed_vecs.clone() self.obj_embed2 = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed2.weight.data = embed_vecs.clone() self.lstm_visual = nn.LSTM(input_size=1536, hidden_size=512) self.lstm_semantic = nn.LSTM(input_size=400, hidden_size=512) self.odeBlock = odeBlock(odeFunc1(bidirectional=True)) self.fc_predicate = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(inplace=False), nn.Linear(512, 51), nn.ReLU(inplace=False))
    val, batch_size=conf.batch_size, num_workers=conf.num_workers,
    num_gpus=conf.num_gpus)
else:
    train, val = HICO.splits(num_val_im=conf.val_size)
    train_loader, val_loader = HICODataLoader.splits(
        train, val, batch_size=conf.batch_size,
        num_workers=conf.num_workers, num_gpus=conf.num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes, num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
detector.cuda()

start_epoch = -1
if conf.ckpt is not None:
    ckpt = torch.load(conf.ckpt)
    if optimistic_restore(detector, ckpt['state_dict']):
        start_epoch = ckpt['epoch']


def val_epoch():
    detector.eval()
    # all_boxes is a list of length number-of-classes.
    # Each list element is a list of length number-of-images.
    num_gpus=conf.num_gpus)
else:
    train, val, _ = VG.splits(num_val_im=conf.val_size, filter_non_overlap=False,
                              filter_empty_rels=False, use_proposals=conf.use_proposals)
    train_loader, val_loader = VGDataLoader.splits(
        train, val, batch_size=conf.batch_size,
        num_workers=conf.num_workers, num_gpus=conf.num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes, num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
detector.cuda()

# Note: if you're doing the stanford setup, you'll need to change this to freeze the lower layers
if conf.use_proposals:
    for n, param in detector.named_parameters():
        if n.startswith('features'):
            param.requires_grad = False

optimizer = optim.SGD([p for p in detector.parameters() if p.requires_grad],
                      weight_decay=conf.l2,
                      lr=conf.lr * conf.num_gpus * conf.batch_size,
                      momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer,
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, require_overlap_det=True, pooling_dim=4096, use_resnet=False, thresh=0.01, use_proposals=False, use_ggnn_obj=False, ggnn_obj_time_step_num=3, ggnn_obj_hidden_dim=512, ggnn_obj_output_dim=512, use_ggnn_rel=False, ggnn_rel_time_step_num=3, ggnn_rel_hidden_dim=512, ggnn_rel_output_dim=512, use_obj_knowledge=True, use_rel_knowledge=True, obj_knowledge='', rel_knowledge=''): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param require_overlap_det: Whether two objects must intersect """ super(KERN, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.obj_dim = 2048 if use_resnet else 4096 self.rel_dim = self.obj_dim self.pooling_dim = pooling_dim self.use_ggnn_obj = use_ggnn_obj self.use_ggnn_rel = use_ggnn_rel self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier if self.use_ggnn_obj: self.ggnn_obj_reason = GGNNObjReason( mode=self.mode, num_obj_cls=len(self.classes), obj_dim=self.obj_dim, time_step_num=ggnn_obj_time_step_num, hidden_dim=ggnn_obj_hidden_dim, output_dim=ggnn_obj_output_dim, use_knowledge=use_obj_knowledge, knowledge_matrix=obj_knowledge) if self.use_ggnn_rel: self.ggnn_rel_reason = GGNNRelReason( mode=self.mode, num_obj_cls=len(self.classes), num_rel_cls=len(rel_classes), obj_dim=self.obj_dim, rel_dim=self.rel_dim, time_step_num=ggnn_rel_time_step_num, hidden_dim=ggnn_rel_hidden_dim, output_dim=ggnn_obj_output_dim, use_knowledge=use_rel_knowledge, knowledge_matrix=rel_knowledge) else: self.vr_fc_cls = VRFC(self.mode, self.rel_dim, len(self.classes), len(self.rel_classes))
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True, embed_dim=200, hidden_dim=256, pooling_dim=2048, nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False, pass_in_obj_feats_to_decoder=True, pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True, limit_vision=True): """ :param classes: Object classes :param rel_classes: Relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: Whether to use vision in the final product :param require_overlap_det: Whether two objects must intersect :param embed_dim: Dimension for all embeddings :param hidden_dim: LSTM hidden size :param obj_dim: """ super(RelModel, self).__init__() self.classes = classes self.rel_classes = rel_classes self.num_gpus = num_gpus assert mode in MODES self.mode = mode self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.pooling_dim = pooling_dim self.use_bias = use_bias self.use_vision = use_vision self.use_tanh = use_tanh self.limit_vision = limit_vision self.require_overlap = require_overlap_det and self.mode == 'sgdet' self.hook_for_grad = False self.gradients = [] self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.ort_embedding = torch.autograd.Variable( get_ort_embeds(self.num_classes, 200).cuda()) embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim) self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed.weight.data = embed_vecs.clone() # This probably doesn't help it much self.pos_embed = nn.Sequential(*[ nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0), nn.Linear(4, 128), nn.ReLU(inplace=True), nn.Dropout(0.1), ]) self.context = LinearizedContext( self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout, order=order, pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge) # Image Feats (You'll have to disable if you want to turn off the features from here) self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) self.merge_obj_feats = nn.Sequential( nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim), nn.ReLU()) # self.trans = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim//4), # LayerNorm(self.hidden_dim//4), nn.ReLU(), # nn.Linear(self.hidden_dim//4, self.hidden_dim)) self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim) self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim) self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim, hidden_size=self.hidden_dim, num_layers=1) self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim) # self.obj_mps2 = Message_Passing4OBJ(self.hidden_dim) self.get_boxes_encode = Boxes_Encode(64) if use_resnet: self.roi_fmap = nn.Sequential( resnet_l4(relu_end=False), nn.AvgPool2d(self.pooling_size), Flattener(), ) else: roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier, ] if pooling_dim != 4096: roi_fmap.append(nn.Linear(4096, pooling_dim)) self.roi_fmap = 
nn.Sequential(*roi_fmap) self.roi_fmap_obj = load_vgg(pretrained=False).classifier ################################### # self.obj_classify_head = nn.Linear(self.pooling_dim, self.num_classes) # self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim//2) # self.post_emb_s.weight = torch.nn.init.xavier_normal(self.post_emb_s.weight, gain=1.0) # self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim//2) # self.post_emb_o.weight = torch.nn.init.xavier_normal(self.post_emb_o.weight, gain=1.0) # self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim//2) # self.merge_obj_high.weight = torch.nn.init.xavier_normal(self.merge_obj_high.weight, gain=1.0) # self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim//2) # self.merge_obj_low.weight = torch.nn.init.xavier_normal(self.merge_obj_low.weight, gain=1.0) # self.rel_compress = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True) # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0) # self.freq_gate = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True) # self.freq_gate.weight = torch.nn.init.xavier_normal(self.freq_gate.weight, gain=1.0) self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim) self.post_emb_s.weight = torch.nn.init.xavier_normal( self.post_emb_s.weight, gain=1.0) self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim) self.post_emb_o.weight = torch.nn.init.xavier_normal( self.post_emb_o.weight, gain=1.0) self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim) self.merge_obj_high.weight = torch.nn.init.xavier_normal( self.merge_obj_high.weight, gain=1.0) self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim) self.merge_obj_low.weight = torch.nn.init.xavier_normal( self.merge_obj_low.weight, gain=1.0) self.rel_compress = nn.Linear(self.pooling_dim + 64, self.num_rels, bias=True) self.rel_compress.weight = torch.nn.init.xavier_normal( self.rel_compress.weight, gain=1.0) self.freq_gate = nn.Linear(self.pooling_dim + 64, self.num_rels, bias=True) self.freq_gate.weight = torch.nn.init.xavier_normal( self.freq_gate.weight, gain=1.0) # self.ranking_module = nn.Sequential(nn.Linear(self.pooling_dim + 64, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, 1)) if self.use_bias: self.freq_bias = FrequencyBias()
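import torch
import torch.nn as nn

# Hedged sketch of the gating idea suggested by freq_gate above: a sigmoid gate computed
# from the same pair features scales the frequency-bias logits before they are added to
# the visual logits. This is one plausible combination, not the repo's verified forward pass.
num_rels, feat_dim = 51, 4096 + 64                        # pooling_dim + box-encoding size
rel_compress = nn.Linear(feat_dim, num_rels, bias=True)
freq_gate = nn.Linear(feat_dim, num_rels, bias=True)

pair_feats = torch.randn(8, feat_dim)
freq_logits = torch.randn(8, num_rels)                    # stand-in for the FrequencyBias output
rel_dists = rel_compress(pair_feats) + torch.sigmoid(freq_gate(pair_feats)) * freq_logits
print(rel_dists.shape)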
def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=False, require_overlap_det=True, embed_dim=200, hidden_dim=4096, use_resnet=False, thresh=0.01, use_proposals=False, use_bias=True, limit_vision=True, depth_model=None, pretrained_depth=False, active_features=None, frozen_features=None, use_embed=False, **kwargs): """ :param classes: object classes :param rel_classes: relationship classes. None if were not using rel mode :param mode: (sgcls, predcls, or sgdet) :param num_gpus: how many GPUS 2 use :param use_vision: enable the contribution of union of bounding boxes :param require_overlap_det: whether two objects must intersect :param embed_dim: word2vec embeddings dimension :param hidden_dim: dimension of the fusion hidden layer :param use_resnet: use resnet as faster-rcnn's backbone :param thresh: faster-rcnn related threshold (Threshold for calling it a good box) :param use_proposals: whether to use region proposal candidates :param use_bias: enable frequency bias :param limit_vision: use truncated version of UoBB features :param depth_model: provided architecture for depth feature extraction :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features) :param frozen_features: what set of features should be frozen (e.g. 'd' : depth) :param use_embed: use word2vec embeddings """ RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus, require_overlap_det, active_features, frozen_features) self.pooling_size = 7 self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.obj_dim = 2048 if use_resnet else 4096 self.use_vision = use_vision self.use_bias = use_bias self.limit_vision = limit_vision # -- Store depth related parameters assert depth_model in DEPTH_MODELS self.depth_model = depth_model self.pretrained_depth = pretrained_depth self.depth_pooling_dim = DEPTH_DIMS[self.depth_model] self.use_embed = use_embed self.detector = nn.Module() features_size = 0 # -- Check whether ResNet is selected as faster-rcnn's backbone if use_resnet: raise ValueError( "The current model does not support ResNet as the Faster-RCNN's backbone." ) """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** This is the part where the different components of the proposed relation detection architecture are defined. In the case of RGB images, we have class probability distribution features, visual features, and the location ones. If we are considering depth images as well, we augment depth features too. 
""" # -- Visual features if self.has_visual: # -- Define faster R-CNN network and it's related feature extractors self.detector = ObjectDetector( classes=classes, mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox', use_resnet=use_resnet, thresh=thresh, max_per_img=64, ) self.roi_fmap_obj = load_vgg(pretrained=False).classifier # -- Define union features if self.use_vision: # -- UoBB pooling module self.union_boxes = UnionBoxesAndFeats( pooling_size=self.pooling_size, stride=16, dim=1024 if use_resnet else 512) # -- UoBB feature extractor roi_fmap = [ Flattener(), load_vgg(use_dropout=False, use_relu=False, use_linear=self.hidden_dim == 4096, pretrained=False).classifier, ] if self.hidden_dim != 4096: roi_fmap.append(nn.Linear(4096, self.hidden_dim)) self.roi_fmap = nn.Sequential(*roi_fmap) # -- Define visual features hidden layer self.visual_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)), nn.ReLU(inplace=True), nn.Dropout(0.8) ]) self.visual_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_VISUAL # -- Location features if self.has_loc: # -- Define location features hidden layer self.location_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.location_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_LOC # -- Class features if self.has_class: if self.use_embed: # -- Define class embeddings embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim) self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) self.obj_embed.weight.data = embed_vecs.clone() classme_input_dim = self.embed_dim if self.use_embed else self.num_classes # -- Define Class features hidden layer self.classme_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) self.classme_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_CLASS # -- Depth features if self.has_depth: # -- Initialize depth backbone self.depth_backbone = DepthCNN(depth_model=self.depth_model, pretrained=self.pretrained_depth) # -- Create a relation head which is used to carry on the feature extraction # from RoIs of depth features self.depth_rel_head = self.depth_backbone.get_classifier() # -- Define depth features hidden layer self.depth_rel_hlayer = nn.Sequential(*[ xavier_init( nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)), nn.ReLU(inplace=True), nn.Dropout(0.6), ]) self.depth_scale = ScaleLayer(1.0) features_size += self.FC_SIZE_DEPTH # -- Initialize frequency bias if needed if self.use_bias: self.freq_bias = FrequencyBias() # -- *** Fusion layer *** -- # -- A hidden layer for concatenated features (fusion features) self.fusion_hlayer = nn.Sequential(*[ xavier_init(nn.Linear(features_size, self.hidden_dim)), nn.ReLU(inplace=True), nn.Dropout(0.1) ]) # -- Final FC layer which predicts the relations self.rel_out = xavier_init( nn.Linear(self.hidden_dim, self.num_rels, bias=True)) # -- Freeze the user specified features if self.frz_visual: self.freeze_module(self.detector) self.freeze_module(self.roi_fmap_obj) self.freeze_module(self.visual_hlayer) if self.use_vision: self.freeze_module(self.roi_fmap) self.freeze_module(self.union_boxes.conv) if self.frz_class: self.freeze_module(self.classme_hlayer) if self.frz_loc: self.freeze_module(self.location_hlayer) if self.frz_depth: self.freeze_module(self.depth_backbone) self.freeze_module(self.depth_rel_head) 
self.freeze_module(self.depth_rel_hlayer)
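import torch
import torch.nn as nn

# Self-contained sketch of the fusion scheme above: each enabled stream (visual, location,
# class, depth) is projected by its own hidden layer, the outputs are concatenated, passed
# through the fusion layer, and mapped to relation logits. The stream sizes are illustrative
# stand-ins for the FC_SIZE_* constants.
stream_sizes = {'visual': 512, 'location': 256, 'class': 256, 'depth': 512}
hidden_dim, num_rels, n_pairs = 4096, 51, 8

parts = [torch.randn(n_pairs, size) for size in stream_sizes.values()]
fusion_hlayer = nn.Sequential(nn.Linear(sum(stream_sizes.values()), hidden_dim),
                              nn.ReLU(inplace=True), nn.Dropout(0.1))
rel_out = nn.Linear(hidden_dim, num_rels, bias=True)

rel_dists = rel_out(fusion_hlayer(torch.cat(parts, dim=1)))
print(rel_dists.shape)                                    # torch.Size([8, 51])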
class neural_motifs_sg2im_model(nn.Module): def __init__(self, args, ind_to_classes): super(neural_motifs_sg2im_model, self).__init__() self.args = args # define and initial detector self.detector = ObjectDetector( classes=ind_to_classes, num_gpus=args.num_gpus, mode='refinerels' if not args.use_proposals else 'proposals', use_resnet=args.use_resnet) if args.ckpt is not None: ckpt = torch.load(args.ckpt) optimistic_restore(self.detector, ckpt['state_dict']) self.detector.eval() # define and initial generator, image_discriminator, obj_discriminator, # and corresponding optimizer vocab = { 'object_idx_to_name': ind_to_classes, } self.model, model_kwargs = build_model(args) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate) self.obj_discriminator, d_obj_kwargs = build_obj_discriminator( args, vocab) self.img_discriminator, d_img_kwargs = build_img_discriminator(args) if self.obj_discriminator is not None: self.obj_discriminator.train() self.optimizer_d_obj = torch.optim.Adam( self.obj_discriminator.parameters(), lr=args.learning_rate) if self.img_discriminator is not None: self.img_discriminator.train() self.optimizer_d_img = torch.optim.Adam( self.img_discriminator.parameters(), lr=args.learning_rate) restore_path = None if args.restore_from_checkpoint: restore_path = '%s_with_model.pt' % args.checkpoint_name restore_path = os.path.join(args.output_dir, restore_path) if restore_path is not None and os.path.isfile(restore_path): print('Restoring from checkpoint:') print(restore_path) checkpoint = torch.load(restore_path) self.model.load_state_dict(checkpoint['model_state']) self.optimizer.load_state_dict(checkpoint['optim_state']) if self.obj_discriminator is not None: self.obj_discriminator.load_state_dict( checkpoint['d_obj_state']) self.optimizer_d_obj.load_state_dict( checkpoint['d_obj_optim_state']) if self.img_discriminator is not None: self.img_discriminator.load_state_dict( checkpoint['d_img_state']) self.optimizer_d_img.load_state_dict( checkpoint['d_img_optim_state']) t = checkpoint['counters']['t'] if 0 <= args.eval_mode_after <= t: self.model.eval() else: self.model.train() epoch = checkpoint['counters']['epoch'] else: t, epoch = 0, 0 checkpoint = { 'vocab': vocab, 'model_kwargs': model_kwargs, 'd_obj_kwargs': d_obj_kwargs, 'd_img_kwargs': d_img_kwargs, 'losses_ts': [], 'losses': defaultdict(list), 'd_losses': defaultdict(list), 'checkpoint_ts': [], 'train_batch_data': [], 'train_samples': [], 'train_iou': [], 'val_batch_data': [], 'val_samples': [], 'val_losses': defaultdict(list), 'val_iou': [], 'norm_d': [], 'norm_g': [], 'counters': { 't': None, 'epoch': None, }, 'model_state': None, 'model_best_state': None, 'optim_state': None, 'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None, 'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None, 'best_t': [], } self.t, self.epoch, self.checkpoint = t, epoch, checkpoint def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None, return_fmap=False): # forward detector with timeit('detector forward', self.args.timing): result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals, train_anchor_inds, return_fmap=True) if result.is_none(): return ValueError("heck") # forward generator imgs = F.interpolate(x, size=self.args.image_size) objs = result.obj_preds boxes = result.rm_box_priors / BOX_SCALE obj_to_img = result.im_inds - image_offset obj_fmap = result.obj_fmap # 
check if all image have detection cnt = torch.zeros(len(imgs)).byte() cnt[obj_to_img] += 1 if (cnt > 0).sum() != len(imgs): print("some imgs have no detection") print(cnt) imgs = imgs[cnt] obj_to_img_new = obj_to_img.clone() for i in range(len(cnt)): if cnt[i] == 0: obj_to_img_new -= (obj_to_img > i).long() obj_to_img = obj_to_img_new with timeit('generator forward', self.args.timing): imgs_pred = self.model(obj_to_img, boxes, obj_fmap) # forward discriminators to train generator if self.obj_discriminator is not None: with timeit('d_obj forward for g', self.args.timing): g_scores_fake_crop, g_obj_scores_fake_crop = self.obj_discriminator( imgs_pred, objs, boxes, obj_to_img) if self.img_discriminator is not None: with timeit('d_img forward for g', self.args.timing): g_scores_fake_img = self.img_discriminator(imgs_pred) # forward discriminators to train discriminators if self.obj_discriminator is not None: imgs_fake = imgs_pred.detach() with timeit('d_obj forward for d', self.args.timing): d_scores_fake_crop, d_obj_scores_fake_crop = self.obj_discriminator( imgs_fake, objs, boxes, obj_to_img) d_scores_real_crop, d_obj_scores_real_crop = self.obj_discriminator( imgs, objs, boxes, obj_to_img) if self.img_discriminator is not None: imgs_fake = imgs_pred.detach() with timeit('d_img forward for d', self.args.timing): d_scores_fake_img = self.img_discriminator(imgs_fake) d_scores_real_img = self.img_discriminator(imgs) return Result(imgs=imgs, imgs_pred=imgs_pred, objs=objs, g_scores_fake_crop=g_scores_fake_crop, g_obj_scores_fake_crop=g_obj_scores_fake_crop, g_scores_fake_img=g_scores_fake_img, d_scores_fake_crop=d_scores_fake_crop, d_obj_scores_fake_crop=d_obj_scores_fake_crop, d_scores_real_crop=d_scores_real_crop, d_obj_scores_real_crop=d_obj_scores_real_crop, d_scores_fake_img=d_scores_fake_img, d_scores_real_img=d_scores_real_img) # return imgs, imgs_pred, objs, g_scores_fake_crop, g_obj_scores_fake_crop, g_scores_fake_img, d_scores_fake_crop, \ # d_obj_scores_fake_crop, d_scores_real_crop, d_obj_scores_real_crop, d_scores_fake_img, d_scores_real_img def __getitem__(self, batch): """ Hack to do multi-GPU training""" batch.scatter() if self.args.num_gpus == 1: return self(*batch[0]) replicas = nn.parallel.replicate(self, devices=list(range( self.args.num_gpus))) outputs = nn.parallel.parallel_apply( replicas, [batch[i] for i in range(self.args.num_gpus)]) if self.training: return gather_res(outputs, 0, dim=0) return outputs
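import torch
import torch.nn as nn

# Minimal sketch of the data-parallel pattern used in __getitem__ above: replicate the
# module across devices and apply each replica to one shard of the batch. The toy module
# and shapes are hypothetical; this only runs when at least two CUDA devices are present.
class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

if torch.cuda.device_count() >= 2:
    model = Toy().cuda()
    shards = [torch.randn(3, 4, device='cuda:%d' % i) for i in range(2)]
    replicas = nn.parallel.replicate(model, devices=[0, 1])
    outputs = nn.parallel.parallel_apply(replicas, [(s,) for s in shards])
    print([o.shape for o in outputs])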