Code Example #1
File: rel_model_align.py Project: MitraTj/vrd_align
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048, use_resnet=False, thresh=0.01,
                 use_proposals=False, rec_dropout=0.0, use_bias=True, use_tanh=True,
                 limit_vision=True, sl_pretrain=False, eval_rel_objs=False, num_iter=-1, reduce_input=False, 
                 post_nms_thresh=0.5):
        super(RelModelAlign, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        # self.pooling_size = 7
        self.pooling_size = conf.pooling_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.num_iter = num_iter
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.sl_pretrain = sl_pretrain

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = DynamicFilterContext(self.classes, self.rel_classes, mode=self.mode,
                                            use_vision=self.use_vision, embed_dim=self.embed_dim, 
                                            hidden_dim=self.hidden_dim, obj_dim=self.obj_dim, 
                                            pooling_dim=self.pooling_dim, pooling_size=self.pooling_size, 
                                            dropout_rate=rec_dropout, 
                                            use_bias=self.use_bias, use_tanh=self.use_tanh,
                                            limit_vision=self.limit_vision,
                                            sl_pretrain=self.sl_pretrain,
                                            num_iter=self.num_iter,
                                            use_resnet=use_resnet,
                                            reduce_input=reduce_input,
                                            post_nms_thresh=post_nms_thresh)
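
The nested conditional passed as mode= to ObjectDetector above recurs in almost every example on this page. A minimal sketch that unrolls it (the helper name detector_mode is illustrative, not from any of these projects):

def detector_mode(mode, use_proposals):
    """Unrolled form of:
    ('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox'
    """
    if mode != 'sgdet':
        # sgcls/predcls: ground-truth boxes are given, so no box detection is needed
        return 'gtbox'
    # sgdet: boxes come from external proposals or from the internal RPN
    return 'proposals' if use_proposals else 'refinerels'

assert detector_mode('predcls', use_proposals=False) == 'gtbox'
assert detector_mode('sgdet', use_proposals=True) == 'proposals'
assert detector_mode('sgdet', use_proposals=False) == 'refinerels'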
Code Example #2
File: load_detector.py Project: LUGUANSONG/i2g2i
        num_gpus=detector_num_gpus)
else:
    train, val, _ = VG.splits(num_val_im=conf.val_size,
                              filter_non_overlap=False,
                              filter_empty_rels=False,
                              use_proposals=conf.use_proposals)
    train_loader, val_loader = VGDataLoader.splits(
        train,
        val,
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
        num_gpus=detector_num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes,
    num_gpus=detector_num_gpus,
    mode='refinerels' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
# print(detector)
# os._exit(0)
detector.cuda()

# Note: if you're doing the stanford setup, you'll need to change this to freeze the lower layers
if conf.use_proposals:
    for n, param in detector.named_parameters():
        if n.startswith('features'):
            param.requires_grad = False

start_epoch = -1
if conf.ckpt is not None:
    ckpt = torch.load(conf.ckpt)
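
The snippet above is cut off right after loading the checkpoint. Judging from the other examples in this listing (e.g. eval_detector.py below), the restore step that presumably follows is a sketch like this: optimistic_restore copies only the parameters whose names and shapes match the current model.

if optimistic_restore(detector, ckpt['state_dict']):
    start_epoch = ckpt['epoch']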
Code Example #3
File: eval_detector.py Project: taksau/sglabv1
from config import ModelConfig, FG_FRACTION, RPN_FG_FRACTION, IM_SCALE, BOX_SCALE
from torch.nn import functional as F
from lib.fpn.box_utils import bbox_loss
import torch.backends.cudnn as cudnn
from lib.pytorch_misc import optimistic_restore, clip_grad_norm
from torch.optim.lr_scheduler import ReduceLROnPlateau
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
torch.cuda.set_device(0)
cudnn.benchmark = True
conf = ModelConfig()

train, val, _ = VG.splits(num_val_im=conf.val_size,
                          filter_non_overlap=False,
                          filter_empty_rels=False,
                          use_proposals=False)
train_loader, val_loader = VGDataLoader.splits(train,
                                               val,
                                               batch_size=1,
                                               num_workers=1,
                                               num_gpus=1)
detector = ObjectDetector(classes=train.ind_to_classes,
                          num_gpus=1,
                          mode='refinerels',
                          use_resnet=False)
detector.cuda()

ckpt = torch.load(conf.ckpt)
optimistic_restore(detector, ckpt['state_dict'])
detector.eval()
for batch in train_loader:
    results = detector[batch]
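
Note the bracket syntax detector[batch]: in this codebase the model is invoked through __getitem__, which dispatches a pre-scattered batch across GPUs. A minimal self-contained sketch of that convention (illustrative only, not the project's actual implementation):

import torch
from torch import nn

class IndexableModule(nn.Module):
    """model[batch]: batch[i] holds the argument tuple for GPU i."""
    def __init__(self, num_gpus=1):
        super().__init__()
        self.num_gpus = num_gpus
        self.net = nn.Linear(4, 2)  # stand-in for the real detector body

    def forward(self, x):
        return self.net(x)

    def __getitem__(self, batch):
        if self.num_gpus == 1:
            return self(*batch[0])  # single GPU: an ordinary forward call
        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        return nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.num_gpus)])

m = IndexableModule()
out = m[[(torch.randn(3, 4),)]]  # equivalent to m(torch.randn(3, 4))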
Code Example #4
File: my_model_33.py Project: youngfly11/gbnet
    def __init__(self, classes, rel_classes, graph_path, emb_path, mode='sgdet', num_gpus=1, 
                 require_overlap_det=True, pooling_dim=4096, use_resnet=False, thresh=0.01,
                 use_proposals=False,
                 ggnn_rel_time_step_num=3,
                 ggnn_rel_hidden_dim=512,
                 ggnn_rel_output_dim=512, use_knowledge=True, use_embedding=True, refine_obj_cls=False,
                 rel_counts_path=None, class_volume=1.0, top_k_to_keep=5, normalize_messages=True):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: Whether two objects must intersect
        """
        super(KERN, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.pooling_size = 7
        self.obj_dim = 2048 if use_resnet else 4096
        self.rel_dim = self.obj_dim
        self.pooling_dim = pooling_dim

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64
        )


        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        self.ggnn_rel_reason = GGNNRelReason(mode=self.mode, 
                                             num_obj_cls=len(self.classes), 
                                             num_rel_cls=len(rel_classes), 
                                             obj_dim=self.obj_dim, 
                                             rel_dim=self.rel_dim, 
                                             time_step_num=ggnn_rel_time_step_num, 
                                             hidden_dim=ggnn_rel_hidden_dim, 
                                             output_dim=ggnn_rel_output_dim,
                                             emb_path=emb_path,
                                             graph_path=graph_path, 
                                             refine_obj_cls=refine_obj_cls, 
                                             use_knowledge=use_knowledge, 
                                             use_embedding=use_embedding,
                                             top_k_to_keep=top_k_to_keep,
                                             normalize_messages=normalize_messages
                                            )

        if rel_counts_path is not None:
            with open(rel_counts_path, 'rb') as fin:
                rel_counts = pickle.load(fin)
            beta = (class_volume - 1.0) / class_volume
            self.rel_class_weights = (1.0 - beta) / (1 - (beta ** rel_counts))
            self.rel_class_weights *= float(self.num_rels) / np.sum(self.rel_class_weights)
        else:
            self.rel_class_weights = np.ones((self.num_rels,))
        self.rel_class_weights = Variable(torch.from_numpy(self.rel_class_weights).float().cuda(), requires_grad=False)
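
The rel_class_weights block above computes "effective number of samples" class-balanced weights, w_c = (1 - beta) / (1 - beta ** n_c) with beta = (V - 1) / V, renormalized to mean 1 (this matches the class-balanced loss of Cui et al., 2019). A worked sketch with toy counts:

import numpy as np

rel_counts = np.array([1000., 100., 10.])    # toy per-predicate counts
class_volume = 100.0
beta = (class_volume - 1.0) / class_volume   # 0.99; beta -> 1 recovers 1/n_c weighting
weights = (1.0 - beta) / (1 - beta ** rel_counts)
weights *= len(rel_counts) / weights.sum()   # renormalize so the mean weight is 1
print(weights)                               # the rarest predicate gets the largest weight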
Code Example #5
File: rel_model.py Project: ht014/lsbr
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 gnn=True,
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim:
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.reachability = reachability
        self.gnn = gnn
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.global_embedding = EmbeddingImagenet(4096)
        self.global_logist = nn.Linear(4096, 151,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        self.disc_center = DiscCentroidsLoss(self.num_rels,
                                             self.pooling_dim + 256)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim + 256, num_classes=self.num_rels)

        # self.global_rel_logist = nn.Linear(4096, 50 , bias=True)
        # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0)

        # self.global_logist = CosineLinear(4096,150)
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        self.edge_coordinate_embedding = nn.Sequential(*[
            nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(5, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half the contribution comes from the LSTM, half from the embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(4096 + 256, 51, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        self.node_transform = nn.Linear(4096, 256, bias=True)
        self.edge_transform = nn.Linear(4096, 256, bias=True)
        # self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        if self.gnn:
            self.graph_network_node = GraphNetwork(4096)
            self.graph_network_edge = GraphNetwork()
            if self.training:
                self.graph_network_node.train()
                self.graph_network_edge.train()
            else:
                self.graph_network_node.eval()
                self.graph_network_edge.eval()
        self.edge_sim_network = nn.Linear(4096, 1, bias=True)
        self.metric_net = MetricLearning()
Code Example #6
File: model.py Project: LUGUANSONG/i2g2i
    def __init__(self, args, ind_to_classes):
        super(neural_motifs_sg2im_model, self).__init__()
        self.args = args

        # define and initialize the detector
        self.detector = ObjectDetector(
            classes=ind_to_classes,
            num_gpus=args.num_gpus,
            mode='refinerels' if not args.use_proposals else 'proposals',
            use_resnet=args.use_resnet)
        if args.ckpt is not None:
            ckpt = torch.load(args.ckpt)
            optimistic_restore(self.detector, ckpt['state_dict'])
        self.detector.eval()

        # define and initialize the generator, image_discriminator, obj_discriminator,
        # and their corresponding optimizers
        vocab = {
            'object_idx_to_name': ind_to_classes,
        }
        self.model, model_kwargs = build_model(args)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=args.learning_rate)

        self.obj_discriminator, d_obj_kwargs = build_obj_discriminator(
            args, vocab)
        self.img_discriminator, d_img_kwargs = build_img_discriminator(args)

        if self.obj_discriminator is not None:
            self.obj_discriminator.train()
            self.optimizer_d_obj = torch.optim.Adam(
                self.obj_discriminator.parameters(), lr=args.learning_rate)

        if self.img_discriminator is not None:
            self.img_discriminator.train()
            self.optimizer_d_img = torch.optim.Adam(
                self.img_discriminator.parameters(), lr=args.learning_rate)

        restore_path = None
        if args.restore_from_checkpoint:
            restore_path = '%s_with_model.pt' % args.checkpoint_name
            restore_path = os.path.join(args.output_dir, restore_path)
        if restore_path is not None and os.path.isfile(restore_path):
            print('Restoring from checkpoint:')
            print(restore_path)
            checkpoint = torch.load(restore_path)
            self.model.load_state_dict(checkpoint['model_state'])
            self.optimizer.load_state_dict(checkpoint['optim_state'])

            if self.obj_discriminator is not None:
                self.obj_discriminator.load_state_dict(
                    checkpoint['d_obj_state'])
                self.optimizer_d_obj.load_state_dict(
                    checkpoint['d_obj_optim_state'])

            if self.img_discriminator is not None:
                self.img_discriminator.load_state_dict(
                    checkpoint['d_img_state'])
                self.optimizer_d_img.load_state_dict(
                    checkpoint['d_img_optim_state'])

            t = checkpoint['counters']['t']
            if 0 <= args.eval_mode_after <= t:
                self.model.eval()
            else:
                self.model.train()
            epoch = checkpoint['counters']['epoch']
        else:
            t, epoch = 0, 0
            checkpoint = {
                'vocab': vocab,
                'model_kwargs': model_kwargs,
                'd_obj_kwargs': d_obj_kwargs,
                'd_img_kwargs': d_img_kwargs,
                'losses_ts': [],
                'losses': defaultdict(list),
                'd_losses': defaultdict(list),
                'checkpoint_ts': [],
                'train_batch_data': [],
                'train_samples': [],
                'train_iou': [],
                'val_batch_data': [],
                'val_samples': [],
                'val_losses': defaultdict(list),
                'val_iou': [],
                'norm_d': [],
                'norm_g': [],
                'counters': {
                    't': None,
                    'epoch': None,
                },
                'model_state': None,
                'model_best_state': None,
                'optim_state': None,
                'd_obj_state': None,
                'd_obj_best_state': None,
                'd_obj_optim_state': None,
                'd_img_state': None,
                'd_img_best_state': None,
                'd_img_optim_state': None,
                'best_t': [],
            }

        self.t, self.epoch, self.checkpoint = t, epoch, checkpoint
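
The else-branch builds an empty checkpoint skeleton, which implies a matching save step somewhere in the training loop. A sketch of that bookkeeping (the method name and the exact set of fields written are assumptions; the path mirrors restore_path above):

import os
import torch

def save_checkpoint(self):
    # Fill the counters that the restore branch reads back.
    self.checkpoint['counters']['t'] = self.t
    self.checkpoint['counters']['epoch'] = self.epoch
    # Snapshot the states restored above under the same keys.
    self.checkpoint['model_state'] = self.model.state_dict()
    self.checkpoint['optim_state'] = self.optimizer.state_dict()
    # Write to the same '%s_with_model.pt' name the restore branch looks for.
    path = os.path.join(self.args.output_dir,
                        '%s_with_model.pt' % self.args.checkpoint_name)
    torch.save(self.checkpoint, path)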
Code Example #7
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 obj_dim=2048,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=True,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 spatial_dim=128,
                 graph_constrain=True,
                 mp_iter_num=1):
        """
        Args:
            mp_iter_num: integer, number of message passing iterations
        """
        super(FckModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = obj_dim
        self.pooling_dim = 2048 if use_resnet else 4096
        self.spatial_dim = spatial_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.graph_cons = graph_constrain
        self.mp_iter_num = mp_iter_num

        classes_word_vec = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim)
        self.classes_word_embedding.weight.data = classes_word_vec.clone()
        self.classes_word_embedding.weight.requires_grad = False

        # the last one is dirty bit
        self.rel_mem = nn.Embedding(self.num_rels, self.obj_dim + 1)
        self.rel_mem.weight.data[:, -1] = 0

        if mode == 'sgdet':
            if use_proposals:
                obj_detector_mode = 'proposals'
            else:
                obj_detector_mode = 'refinerels'
        else:
            obj_detector_mode = 'gtbox'

        self.detector = ObjectDetector(
            classes=classes,
            mode=obj_detector_mode,
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512,
                                              use_feats=False)
        self.spatial_fc = nn.Sequential(*[
            nn.Linear(4, spatial_dim),
            nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        self.word_fc = nn.Sequential(*[
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        # union box feats
        feats_dim = obj_dim + spatial_dim + hidden_dim
        self.relpn_fc = nn.Linear(feats_dim, 2)
        self.relcnn_fc1 = nn.Sequential(
            *[nn.Linear(feats_dim, feats_dim),
              nn.ReLU(inplace=True)])
        self.box_mp_fc = nn.Sequential(*[
            nn.Linear(obj_dim, obj_dim),
        ])
        self.sub_rel_mp_fc = nn.Sequential(*[nn.Linear(feats_dim, obj_dim)])

        self.obj_rel_mp_fc = nn.Sequential(*[
            nn.Linear(feats_dim, obj_dim),
        ])

        self.mp_atten_fc = nn.Sequential(*[
            nn.Linear(feats_dim + obj_dim, obj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(obj_dim, 1)
        ])

        self.cls_fc = nn.Linear(obj_dim, self.num_classes)
        self.relcnn_fc2 = nn.Linear(
            feats_dim, self.num_rels if self.graph_cons else 2 * self.num_rels)

        if use_resnet:
            # deprecated
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                load_vgg(
                    use_dropout=False,
                    use_relu=False,
                    use_linear=self.obj_dim == 4096,
                    pretrained=False,
                ).classifier,
                nn.Linear(self.pooling_dim, self.obj_dim)
            ]
            self.roi_fmap = nn.Sequential(*roi_fmap)
Code Example #8
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True,
                 limit_vision=True):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim:
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # print('REL MODEL CONSTRUCTOR: 1')
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        # print('REL MODEL CONSTRUCTOR: 2')
        self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode,
                                         embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout,
                                         order=order,
                                         pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
                                         pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)
        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)
        # print('REL MODEL CONSTRUCTOR: 3')
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
        # print('REL MODEL CONSTRUCTOR: 4')
        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half the contribution comes from the LSTM, half from the embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()
        # print('REL MODEL CONSTRUCTOR: 5')
        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
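
A quick numerical check of the initialization comment above: with weights drawn from N(0, (10 * sqrt(1/hidden_dim))^2) and inputs of stdev ~0.1, each output of post_lstm indeed comes out with stdev ~1. This sketch is a sanity check, not code from the project:

import math
import torch
from torch import nn

hidden_dim, pooling_dim = 256, 2048
post_lstm = nn.Linear(hidden_dim, pooling_dim * 2)
post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / hidden_dim))
post_lstm.bias.data.zero_()

x = 0.1 * torch.randn(10000, hidden_dim)  # pre-LSTM features with stdev ~0.1
print(post_lstm(x).std())                 # ~1.0, since 0.1 * 10 * sqrt(1/h) * sqrt(h) = 1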
Code Example #9
if conf.test:
    print("test data!")
    val = test
train_loader, val_loader = VGDataLoader.splits(train,
                                               val,
                                               mode='rel',
                                               batch_size=conf.batch_size,
                                               num_workers=conf.num_workers,
                                               num_gpus=conf.num_gpus)

fg_matrix, bg_matrix = get_counts(train_data=train, must_overlap=MUST_OVERLAP)

detector = ObjectDetector(
    classes=train.ind_to_classes,
    num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet,
    nms_filter_duplicates=True,
    thresh=0.01)
detector.eval()
detector.cuda()

classifier = ObjectDetector(classes=train.ind_to_classes,
                            num_gpus=conf.num_gpus,
                            mode='gtbox',
                            use_resnet=conf.use_resnet,
                            nms_filter_duplicates=True,
                            thresh=0.01)
classifier.eval()
classifier.cuda()
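
This example builds two frozen copies of the same backbone: detector proposes boxes ('rpntrain' / 'proposals') while classifier scores ground-truth boxes ('gtbox'). A sketch of how the pair would typically be consumed (the loop body is an illustration, not from the source):

with torch.no_grad():  # evaluation only, no gradients needed
    for batch in val_loader:
        det_result = detector[batch]    # proposed boxes plus object scores
        cls_result = classifier[batch]  # class scores on ground-truth boxes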
Code Example #10
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=False,
                 embed_dim=200,
                 hidden_dim=256,
                 obj_dim=2048,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=True,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 spatial_dim=128,
                 mp_iter_num=1,
                 trim_graph=True):
        """
        Args:
            mp_iter_num: integer, number of message passing iterations
            trim_graph: boolean, whether to trim the graph in the relation proposal network
        """
        super(FckModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = obj_dim
        self.pooling_dim = 2048 if use_resnet else 4096
        self.spatial_dim = spatial_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.mp_iter_num = mp_iter_num
        self.trim_graph = trim_graph

        classes_word_vec = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim)
        self.classes_word_embedding.weight.data = classes_word_vec.clone()
        self.classes_word_embedding.weight.requires_grad = False

        #fg_matrix, bg_matrix = get_counts()
        #rel_obj_distribution = fg_matrix / (fg_matrix.sum(2)[:, :, None] + 1e-5)
        #rel_obj_distribution = torch.FloatTensor(rel_obj_distribution)
        #rel_obj_distribution = rel_obj_distribution.view(-1, self.num_rels)
        #
        #self.rel_obj_distribution = nn.Embedding(rel_obj_distribution.size(0), self.num_rels)
        ## (#obj_class * #obj_class, #rel_class)
        #self.rel_obj_distribution.weight.data = rel_obj_distribution

        if mode == 'sgdet':
            if use_proposals:
                obj_detector_mode = 'proposals'
            else:
                obj_detector_mode = 'refinerels'
        else:
            obj_detector_mode = 'gtbox'

        self.detector = ObjectDetector(
            classes=classes,
            mode=obj_detector_mode,
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512,
                                              use_feats=False)
        self.spatial_fc = nn.Sequential(*[
            nn.Linear(4, spatial_dim),
            nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        self.word_fc = nn.Sequential(*[
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        # union box feats
        feats_dim = obj_dim + spatial_dim + hidden_dim
        self.relpn_fc = nn.Linear(feats_dim, 2)
        self.relcnn_fc1 = nn.Sequential(
            *[nn.Linear(feats_dim, feats_dim),
              nn.ReLU(inplace=True)])

        # v2 model---------
        self.box_mp_fc = nn.Sequential(*[
            nn.Linear(obj_dim, obj_dim),
        ])
        self.sub_rel_mp_fc = nn.Sequential(*[nn.Linear(feats_dim, obj_dim)])

        self.obj_rel_mp_fc = nn.Sequential(*[
            nn.Linear(feats_dim, obj_dim),
        ])

        self.mp_atten_fc = nn.Sequential(*[
            nn.Linear(feats_dim + obj_dim, obj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(obj_dim, 1)
        ])
        # v2 model----------

        self.cls_fc = nn.Linear(obj_dim, self.num_classes)

        self.relcnn_fc2 = nn.Linear(feats_dim, self.num_rels)

        # v3 model -----------

        self.mem_module = MemoryRNN(classes=classes,
                                    rel_classes=rel_classes,
                                    inputs_dim=feats_dim,
                                    hidden_dim=hidden_dim,
                                    recurrent_dropout_probability=.0)
        # v3 model -----------

        if use_resnet:
            # deprecated
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                load_vgg(
                    use_dropout=False,
                    use_relu=False,
                    use_linear=self.obj_dim == 4096,
                    pretrained=False,
                ).classifier,
                nn.Linear(self.pooling_dim, self.obj_dim)
            ]
            self.roi_fmap = nn.Sequential(*roi_fmap)
Code Example #11
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Args:
            classes: list, list of 151 object class names (including background)
            rel_classes: list, list of 51 predicate names (including the background / no-relationship class)
            mode: string, 'sgdet', 'predcls' or 'sgcls'
            num_gpus: integer, number of GPUs to use
            use_vision: boolean, whether to use vision in the final product
            require_overlap_det: boolean, whether two objects must intersect
            embed_dim: integer, number of dimensions for all embeddings
            hidden_dim: integer, hidden size of the LSTM
            pooling_dim: integer, output size of the VGG fc layer
            nl_obj: integer, number of object context layers, 2 in the paper
            nl_edge: integer, number of edge context layers, 4 in the paper
            use_resnet: boolean, whether to use ResNet for the backbone
            order: string, must be one of ('size', 'confidence', 'random', 'leftright'), ordering of the RoIs
            thresh: float, score threshold for boxes;
                boxes scoring below thresh are discarded
            use_proposals: boolean, whether to use proposals
            pass_in_obj_feats_to_decoder: boolean, whether to pass object features to the decoder RNN
            pass_in_obj_feats_to_edge: boolean, whether to pass object features to the edge context RNN
            rec_dropout: float, dropout rate in the RNN
            use_bias: boolean,
            use_tanh: boolean,
            limit_vision: boolean,
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half the contribution comes from the LSTM, half from the embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
Code Example #12
File: rel_model3.py Project: ht014/lsbr
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 model_path='',
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 init_center=False,
                 limit_vision=True):

        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.init_center = init_center
        self.pooling_size = 7
        self.model_path = model_path
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim
        self.centroids = None
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.global_embedding = EmbeddingImagenet(4096)
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)
        self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_rels)
        self.disc_center_g = DiscCentroidsLoss(self.num_classes,
                                               self.pooling_dim)
        self.meta_classify_g = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_classes)
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half the contribution comes from the LSTM, half from the embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        self.global_logist = nn.Linear(self.pooling_dim,
                                       self.num_classes,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
        self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        self.class_num = torch.zeros(len(self.classes))
        self.centroids = torch.zeros(len(self.classes),
                                     self.pooling_dim).cuda()
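
class_num and centroids above are zero-initialized buffers, which suggests a running per-class centroid estimate maintained during training. One plausible update rule, as a sketch (an assumption, not code from the project):

def update_centroids(self, feats, labels):
    # Incremental per-class mean: c <- c + (x - c) / n
    for f, y in zip(feats, labels):
        self.class_num[y] += 1
        self.centroids[y] += (f - self.centroids[y]) / self.class_num[y]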
Code Example #13
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, require_overlap_det=True,
                 embed_dim=200, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        """
        super(NODIS, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim

        self.obj_dim = 2048 if use_resnet else 4096


        self.order = 'random'

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = O_NODE(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, obj_dim=self.obj_dim, order=order)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
            self.roi_avg_pool = nn.AvgPool2d(kernel_size=7, stride=0)
        ###################################
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        self.obj_embed2 = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed2.weight.data = embed_vecs.clone()

        self.lstm_visual = nn.LSTM(input_size=1536, hidden_size=512)
        self.lstm_semantic = nn.LSTM(input_size=400, hidden_size=512)
        self.odeBlock = odeBlock(odeFunc1(bidirectional=True))

        self.fc_predicate = nn.Sequential(nn.Linear(1024, 512),
                                          nn.ReLU(inplace=False),
                                          nn.Linear(512, 51),
                                          nn.ReLU(inplace=False))
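
The predicate head takes 1024-dim input, which matches the concatenation of the two 512-dim LSTM streams (visual and semantic), and 400 = 2 * embed_dim is consistent with paired subject/object embeddings. A shape-only sketch under those assumptions:

import torch
from torch import nn

lstm_visual = nn.LSTM(input_size=1536, hidden_size=512)
lstm_semantic = nn.LSTM(input_size=400, hidden_size=512)
fc_predicate = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                             nn.Linear(512, 51), nn.ReLU())

n = 8                                            # number of object pairs
vis, _ = lstm_visual(torch.randn(1, n, 1536))    # -> (1, n, 512)
sem, _ = lstm_semantic(torch.randn(1, n, 400))   # -> (1, n, 512)
print(fc_predicate(torch.cat([vis, sem], -1)).shape)  # (1, n, 51) predicate logits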
Code Example #14
        val,
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
        num_gpus=conf.num_gpus)
else:
    train, val = HICO.splits(num_val_im=conf.val_size)
    train_loader, val_loader = HICODataLoader.splits(
        train,
        val,
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
        num_gpus=conf.num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes,
    num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
detector.cuda()

start_epoch = -1
if conf.ckpt is not None:
    ckpt = torch.load(conf.ckpt)
    if optimistic_restore(detector, ckpt['state_dict']):
        start_epoch = ckpt['epoch']


def val_epoch():
    detector.eval()
    # all_boxes is a list of length number-of-classes.
    # Each list element is a list of length number-of-images.
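
The truncated comment describes the standard VOC-style detection-eval layout; a sketch of how all_boxes would be initialized under it (an illustration, not the project's code):

# all_boxes[cls][img] will hold an (N, 5) array of [x1, y1, x2, y2, score]
# detections for class cls in image img.
num_classes = len(train.ind_to_classes)
num_images = len(val)
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]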
Code Example #15
        num_gpus=conf.num_gpus)
else:
    train, val, _ = VG.splits(num_val_im=conf.val_size,
                              filter_non_overlap=False,
                              filter_empty_rels=False,
                              use_proposals=conf.use_proposals)
    train_loader, val_loader = VGDataLoader.splits(
        train,
        val,
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
        num_gpus=conf.num_gpus)

detector = ObjectDetector(
    classes=train.ind_to_classes,
    num_gpus=conf.num_gpus,
    mode='rpntrain' if not conf.use_proposals else 'proposals',
    use_resnet=conf.use_resnet)
detector.cuda()

# Note: if you're doing the stanford setup, you'll need to change this to freeze the lower layers
if conf.use_proposals:
    for n, param in detector.named_parameters():
        if n.startswith('features'):
            param.requires_grad = False

optimizer = optim.SGD([p for p in detector.parameters() if p.requires_grad],
                      weight_decay=conf.l2,
                      lr=conf.lr * conf.num_gpus * conf.batch_size,
                      momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer,
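
The learning rate above is scaled linearly with num_gpus and batch_size, the usual linear-scaling heuristic for larger effective batches. The snippet is cut off inside the ReduceLROnPlateau call; a typical completion, with argument values that are assumptions rather than the project's:

scheduler = ReduceLROnPlateau(optimizer,
                              mode='max',   # e.g. when tracking validation mAP
                              factor=0.1,   # cut the LR by 10x on a plateau
                              patience=3)
# then, once per validation epoch:
# scheduler.step(val_map)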
Code Example #16
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 pooling_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_ggnn_obj=False,
                 ggnn_obj_time_step_num=3,
                 ggnn_obj_hidden_dim=512,
                 ggnn_obj_output_dim=512,
                 use_ggnn_rel=False,
                 ggnn_rel_time_step_num=3,
                 ggnn_rel_hidden_dim=512,
                 ggnn_rel_output_dim=512,
                 use_obj_knowledge=True,
                 use_rel_knowledge=True,
                 obj_knowledge='',
                 rel_knowledge=''):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: Whether two objects must intersect
        """
        super(KERN, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.pooling_size = 7
        self.obj_dim = 2048 if use_resnet else 4096
        self.rel_dim = self.obj_dim
        self.pooling_dim = pooling_dim

        self.use_ggnn_obj = use_ggnn_obj
        self.use_ggnn_rel = use_ggnn_rel

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64)

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        if self.use_ggnn_obj:
            self.ggnn_obj_reason = GGNNObjReason(
                mode=self.mode,
                num_obj_cls=len(self.classes),
                obj_dim=self.obj_dim,
                time_step_num=ggnn_obj_time_step_num,
                hidden_dim=ggnn_obj_hidden_dim,
                output_dim=ggnn_obj_output_dim,
                use_knowledge=use_obj_knowledge,
                knowledge_matrix=obj_knowledge)

        if self.use_ggnn_rel:
            self.ggnn_rel_reason = GGNNRelReason(
                mode=self.mode,
                num_obj_cls=len(self.classes),
                num_rel_cls=len(rel_classes),
                obj_dim=self.obj_dim,
                rel_dim=self.rel_dim,
                time_step_num=ggnn_rel_time_step_num,
                hidden_dim=ggnn_rel_hidden_dim,
                output_dim=ggnn_obj_output_dim,
                use_knowledge=use_rel_knowledge,
                knowledge_matrix=rel_knowledge)
        else:
            self.vr_fc_cls = VRFC(self.mode, self.rel_dim, len(self.classes),
                                  len(self.rel_classes))
Code Example #17
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim: object feature dimension (2048 with ResNet, 4096 with VGG; derived from use_resnet, not passed in)
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.hook_for_grad = False
        self.gradients = []

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, 200).cuda())
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        # This probably doesn't help it much
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        self.merge_obj_feats = nn.Sequential(
            nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim),
            nn.ReLU())

        # self.trans = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim//4),
        #                             LayerNorm(self.hidden_dim//4), nn.ReLU(),
        #                             nn.Linear(self.hidden_dim//4, self.hidden_dim))

        self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim)

        self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim)

        self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)

        self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim)
        # self.obj_mps2 = Message_Passing4OBJ(self.hidden_dim)
        self.get_boxes_encode = Boxes_Encode(64)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # self.obj_classify_head = nn.Linear(self.pooling_dim, self.num_classes)

        # self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_s.weight = torch.nn.init.xavier_normal(self.post_emb_s.weight, gain=1.0)
        # self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_o.weight = torch.nn.init.xavier_normal(self.post_emb_o.weight, gain=1.0)
        # self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim//2)
        # self.merge_obj_high.weight = torch.nn.init.xavier_normal(self.merge_obj_high.weight, gain=1.0)
        # self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim//2)
        # self.merge_obj_low.weight = torch.nn.init.xavier_normal(self.merge_obj_low.weight, gain=1.0)
        # self.rel_compress = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        # self.freq_gate = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.freq_gate.weight = torch.nn.init.xavier_normal(self.freq_gate.weight, gain=1.0)

        self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim)
        torch.nn.init.xavier_normal_(self.post_emb_s.weight, gain=1.0)
        self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim)
        torch.nn.init.xavier_normal_(self.post_emb_o.weight, gain=1.0)
        self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim)
        torch.nn.init.xavier_normal_(self.merge_obj_high.weight, gain=1.0)
        self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim,
                                       self.pooling_dim)
        torch.nn.init.xavier_normal_(self.merge_obj_low.weight, gain=1.0)
        self.rel_compress = nn.Linear(self.pooling_dim + 64,
                                      self.num_rels,
                                      bias=True)
        torch.nn.init.xavier_normal_(self.rel_compress.weight, gain=1.0)
        self.freq_gate = nn.Linear(self.pooling_dim + 64,
                                   self.num_rels,
                                   bias=True)
        torch.nn.init.xavier_normal_(self.freq_gate.weight, gain=1.0)
        # self.ranking_module = nn.Sequential(nn.Linear(self.pooling_dim + 64, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, 1))
        if self.use_bias:
            self.freq_bias = FrequencyBias()
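
The paired rel_compress / freq_gate layers above suggest gated frequency-bias scoring in the style of Neural Motifs. Below is a hedged sketch of how they would combine at score time; vr, obj_preds, and rel_inds are assumed names, and FrequencyBias.index_with_labels follows the Neural Motifs API.

# Hedged sketch: score predicates from the fused pair feature `vr`
# (size pooling_dim + 64), then add a gated frequency bias looked up
# from the predicted subject/object labels. All inputs are assumed names.
rel_dists = self.rel_compress(vr)
if self.use_bias:
    gate = torch.sigmoid(self.freq_gate(vr))  # per-predicate gate in [0, 1]
    freq = self.freq_bias.index_with_labels(
        torch.stack((obj_preds[rel_inds[:, 1]],
                     obj_preds[rel_inds[:, 2]]), 1))
    rel_dists = rel_dists + gate * freq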
Code example #19
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=False,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_bias=True,
                 limit_vision=True,
                 depth_model=None,
                 pretrained_depth=False,
                 active_features=None,
                 frozen_features=None,
                 use_embed=False,
                 **kwargs):
        """
        :param classes: object classes
        :param rel_classes: relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: enable the contribution of union of bounding boxes
        :param require_overlap_det: whether two objects must intersect
        :param embed_dim: word2vec embeddings dimension
        :param hidden_dim: dimension of the fusion hidden layer
        :param use_resnet: use resnet as faster-rcnn's backbone
        :param thresh: Faster R-CNN score threshold for keeping a box as a good detection
        :param use_proposals: whether to use region proposal candidates
        :param use_bias: enable frequency bias
        :param limit_vision: use truncated version of UoBB features
        :param depth_model: provided architecture for depth feature extraction
        :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights
        :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features)
        :param frozen_features: what set of features should be frozen (e.g. 'd' : depth)
        :param use_embed: use word2vec embeddings
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det, active_features,
                              frozen_features)
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.use_vision = use_vision
        self.use_bias = use_bias
        self.limit_vision = limit_vision

        # -- Store depth related parameters
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[self.depth_model]
        self.use_embed = use_embed
        self.detector = nn.Module()
        features_size = 0

        # -- Check whether ResNet is selected as faster-rcnn's backbone
        if use_resnet:
            raise ValueError(
                "The current model does not support ResNet as the Faster-RCNN's backbone."
            )
        """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** 
        This is the part where the different components of the proposed relation detection 
        architecture are defined. In the case of RGB images, we have class probability distribution
        features, visual features, and the location ones. If we are considering depth images as well,
        we augment depth features too. """

        # -- Visual features
        if self.has_visual:
            # -- Define faster R-CNN network and it's related feature extractors
            self.detector = ObjectDetector(
                classes=classes,
                mode=('proposals' if use_proposals else 'refinerels')
                if mode == 'sgdet' else 'gtbox',
                use_resnet=use_resnet,
                thresh=thresh,
                max_per_img=64,
            )
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

            # -- Define union features
            if self.use_vision:
                # -- UoBB pooling module
                self.union_boxes = UnionBoxesAndFeats(
                    pooling_size=self.pooling_size,
                    stride=16,
                    dim=1024 if use_resnet else 512)

                # -- UoBB feature extractor
                roi_fmap = [
                    Flattener(),
                    load_vgg(use_dropout=False,
                             use_relu=False,
                             use_linear=self.hidden_dim == 4096,
                             pretrained=False).classifier,
                ]
                if self.hidden_dim != 4096:
                    roi_fmap.append(nn.Linear(4096, self.hidden_dim))
                self.roi_fmap = nn.Sequential(*roi_fmap)

            # -- Define visual features hidden layer
            self.visual_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.8)
            ])
            self.visual_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_VISUAL

        # -- Location features
        if self.has_loc:
            # -- Define location features hidden layer
            self.location_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.location_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_LOC

        # -- Class features
        if self.has_class:
            if self.use_embed:
                # -- Define class embeddings
                embed_vecs = obj_edge_vectors(self.classes,
                                              wv_dim=self.embed_dim)
                self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
                self.obj_embed.weight.data = embed_vecs.clone()

            classme_input_dim = self.embed_dim if self.use_embed else self.num_classes
            # -- Define Class features hidden layer
            self.classme_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.classme_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_CLASS

        # -- Depth features
        if self.has_depth:
            # -- Initialize depth backbone
            self.depth_backbone = DepthCNN(depth_model=self.depth_model,
                                           pretrained=self.pretrained_depth)

            # -- Create a relation head which is used to carry on the feature extraction
            # from RoIs of depth features
            self.depth_rel_head = self.depth_backbone.get_classifier()

            # -- Define depth features hidden layer
            self.depth_rel_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
            ])
            self.depth_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_DEPTH

        # -- Initialize frequency bias if needed
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # -- *** Fusion layer *** --
        # -- A hidden layer for concatenated features (fusion features)
        self.fusion_hlayer = nn.Sequential(*[
            xavier_init(nn.Linear(features_size, self.hidden_dim)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        ])

        # -- Final FC layer which predicts the relations
        self.rel_out = xavier_init(
            nn.Linear(self.hidden_dim, self.num_rels, bias=True))

        # -- Freeze the user specified features
        if self.frz_visual:
            self.freeze_module(self.detector)
            self.freeze_module(self.roi_fmap_obj)
            self.freeze_module(self.visual_hlayer)
            if self.use_vision:
                self.freeze_module(self.roi_fmap)
                self.freeze_module(self.union_boxes.conv)

        if self.frz_class:
            self.freeze_module(self.classme_hlayer)

        if self.frz_loc:
            self.freeze_module(self.location_hlayer)

        if self.frz_depth:
            self.freeze_module(self.depth_backbone)
            self.freeze_module(self.depth_rel_head)
            self.freeze_module(self.depth_rel_hlayer)
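
Each enabled stream above ends in a scaled hidden layer whose width is added to features_size, so the fusion layer expects the streams concatenated in that same order. A hedged sketch of the corresponding forward pass follows; the *_pair_feats inputs and pair_labels are assumed names.

# Hedged sketch: concatenate the active feature streams in the order their
# sizes were accumulated into `features_size`, then fuse and classify.
feats = []
if self.has_visual:
    feats.append(self.visual_scale(self.visual_hlayer(visual_pair_feats)))
if self.has_loc:
    feats.append(self.location_scale(self.location_hlayer(loc_feats)))
if self.has_class:
    feats.append(self.classme_scale(self.classme_hlayer(classme_pair_feats)))
if self.has_depth:
    feats.append(self.depth_scale(self.depth_rel_hlayer(depth_pair_feats)))
rel_dists = self.rel_out(self.fusion_hlayer(torch.cat(feats, 1)))
if self.use_bias:
    rel_dists = rel_dists + self.freq_bias.index_with_labels(pair_labels)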
Code example #20
File: model.py Project: LUGUANSONG/i2g2i
class neural_motifs_sg2im_model(nn.Module):
    def __init__(self, args, ind_to_classes):
        super(neural_motifs_sg2im_model, self).__init__()
        self.args = args

        # define and initial detector
        self.detector = ObjectDetector(
            classes=ind_to_classes,
            num_gpus=args.num_gpus,
            mode='refinerels' if not args.use_proposals else 'proposals',
            use_resnet=args.use_resnet)
        if args.ckpt is not None:
            ckpt = torch.load(args.ckpt)
            optimistic_restore(self.detector, ckpt['state_dict'])
        self.detector.eval()

        # define and initial generator, image_discriminator, obj_discriminator,
        # and corresponding optimizer
        vocab = {
            'object_idx_to_name': ind_to_classes,
        }
        self.model, model_kwargs = build_model(args)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=args.learning_rate)

        self.obj_discriminator, d_obj_kwargs = build_obj_discriminator(
            args, vocab)
        self.img_discriminator, d_img_kwargs = build_img_discriminator(args)

        if self.obj_discriminator is not None:
            self.obj_discriminator.train()
            self.optimizer_d_obj = torch.optim.Adam(
                self.obj_discriminator.parameters(), lr=args.learning_rate)

        if self.img_discriminator is not None:
            self.img_discriminator.train()
            self.optimizer_d_img = torch.optim.Adam(
                self.img_discriminator.parameters(), lr=args.learning_rate)

        restore_path = None
        if args.restore_from_checkpoint:
            restore_path = '%s_with_model.pt' % args.checkpoint_name
            restore_path = os.path.join(args.output_dir, restore_path)
        if restore_path is not None and os.path.isfile(restore_path):
            print('Restoring from checkpoint:')
            print(restore_path)
            checkpoint = torch.load(restore_path)
            self.model.load_state_dict(checkpoint['model_state'])
            self.optimizer.load_state_dict(checkpoint['optim_state'])

            if self.obj_discriminator is not None:
                self.obj_discriminator.load_state_dict(
                    checkpoint['d_obj_state'])
                self.optimizer_d_obj.load_state_dict(
                    checkpoint['d_obj_optim_state'])

            if self.img_discriminator is not None:
                self.img_discriminator.load_state_dict(
                    checkpoint['d_img_state'])
                self.optimizer_d_img.load_state_dict(
                    checkpoint['d_img_optim_state'])

            t = checkpoint['counters']['t']
            if 0 <= args.eval_mode_after <= t:
                self.model.eval()
            else:
                self.model.train()
            epoch = checkpoint['counters']['epoch']
        else:
            t, epoch = 0, 0
            checkpoint = {
                'vocab': vocab,
                'model_kwargs': model_kwargs,
                'd_obj_kwargs': d_obj_kwargs,
                'd_img_kwargs': d_img_kwargs,
                'losses_ts': [],
                'losses': defaultdict(list),
                'd_losses': defaultdict(list),
                'checkpoint_ts': [],
                'train_batch_data': [],
                'train_samples': [],
                'train_iou': [],
                'val_batch_data': [],
                'val_samples': [],
                'val_losses': defaultdict(list),
                'val_iou': [],
                'norm_d': [],
                'norm_g': [],
                'counters': {
                    't': None,
                    'epoch': None,
                },
                'model_state': None,
                'model_best_state': None,
                'optim_state': None,
                'd_obj_state': None,
                'd_obj_best_state': None,
                'd_obj_optim_state': None,
                'd_img_state': None,
                'd_img_best_state': None,
                'd_img_optim_state': None,
                'best_t': [],
            }

        self.t, self.epoch, self.checkpoint = t, epoch, checkpoint

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        # forward detector
        with timeit('detector forward', self.args.timing):
            result = self.detector(x,
                                   im_sizes,
                                   image_offset,
                                   gt_boxes,
                                   gt_classes,
                                   gt_rels,
                                   proposals,
                                   train_anchor_inds,
                                   return_fmap=True)
        if result.is_none():
            raise ValueError("detector returned an empty result")

        # forward generator
        imgs = F.interpolate(x, size=self.args.image_size)
        objs = result.obj_preds
        boxes = result.rm_box_priors / BOX_SCALE
        obj_to_img = result.im_inds - image_offset
        obj_fmap = result.obj_fmap

        # check that every image has at least one detection; drop the images that don't
        cnt = torch.zeros(len(imgs)).byte()
        cnt[obj_to_img] += 1
        if (cnt > 0).sum() != len(imgs):
            print("some imgs have no detection")
            print(cnt)
            imgs = imgs[cnt]
            obj_to_img_new = obj_to_img.clone()
            for i in range(len(cnt)):
                if cnt[i] == 0:
                    obj_to_img_new -= (obj_to_img > i).long()
            obj_to_img = obj_to_img_new

        with timeit('generator forward', self.args.timing):
            imgs_pred = self.model(obj_to_img, boxes, obj_fmap)

        # forward discriminators to train generator
        if self.obj_discriminator is not None:
            with timeit('d_obj forward for g', self.args.timing):
                g_scores_fake_crop, g_obj_scores_fake_crop = self.obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)

        if self.img_discriminator is not None:
            with timeit('d_img forward for g', self.args.timing):
                g_scores_fake_img = self.img_discriminator(imgs_pred)

        # forward discriminators to train discriminators
        if self.obj_discriminator is not None:
            imgs_fake = imgs_pred.detach()
            with timeit('d_obj forward for d', self.args.timing):
                d_scores_fake_crop, d_obj_scores_fake_crop = self.obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                d_scores_real_crop, d_obj_scores_real_crop = self.obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

        if self.img_discriminator is not None:
            imgs_fake = imgs_pred.detach()
            with timeit('d_img forward for d', self.args.timing):
                d_scores_fake_img = self.img_discriminator(imgs_fake)
                d_scores_real_img = self.img_discriminator(imgs)

        return Result(imgs=imgs,
                      imgs_pred=imgs_pred,
                      objs=objs,
                      g_scores_fake_crop=g_scores_fake_crop,
                      g_obj_scores_fake_crop=g_obj_scores_fake_crop,
                      g_scores_fake_img=g_scores_fake_img,
                      d_scores_fake_crop=d_scores_fake_crop,
                      d_obj_scores_fake_crop=d_obj_scores_fake_crop,
                      d_scores_real_crop=d_scores_real_crop,
                      d_obj_scores_real_crop=d_obj_scores_real_crop,
                      d_scores_fake_img=d_scores_fake_img,
                      d_scores_real_img=d_scores_real_img)
        # return imgs, imgs_pred, objs, g_scores_fake_crop, g_obj_scores_fake_crop, g_scores_fake_img, d_scores_fake_crop, \
        #        d_obj_scores_fake_crop, d_scores_real_crop, d_obj_scores_real_crop, d_scores_fake_img, d_scores_real_img

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.args.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self,
                                         devices=list(range(
                                             self.args.num_gpus)))
        outputs = nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.args.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
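
The __getitem__ override above is the usual neural-motifs multi-GPU trick: subscripting the module with a batch scatters it across GPUs, runs the replicas in parallel, and gathers the results. A hedged usage sketch, reusing the args and loader conventions from earlier in this document:

# Hedged usage sketch: drive training through the __getitem__ hack above.
model = neural_motifs_sg2im_model(args, train.ind_to_classes)
model.cuda()
for batch in train_loader:
    result = model[batch]  # scatter across args.num_gpus, forward, gather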