Example #1
0
def test_net_data_api1(sess, net, output_dir, h_box, o_box, o_cls, h_score, o_score, image_id,
                       *, total_images=10566):
    """Evaluate the network until the input pipeline is exhausted and pickle the results.

    Each step runs the network's prediction heads together with the per-pair
    box/class/score tensors and the image-id tensor. The loop stops when
    ``sess.run`` raises ``tf.errors.OutOfRangeError``, presumably because the
    tensors come from an exhausted dataset iterator.

    Args:
        sess: TensorFlow session used to evaluate the tensors.
        net: network object exposing a ``predictions`` dict of tensors and a
            ``model_name`` string.
        output_dir: path of the pickle file to write. Despite its name, it is
            passed straight to ``open`` as a file path.
        h_box, o_box, o_cls, h_score, o_score, image_id: tensors fetched
            alongside the predictions at every step.
        total_images: denominator shown in the progress line only; it does not
            limit the loop.

    Side effects:
        Writes a dict to ``output_dir`` with pickle. The dict maps each fetched
        image id to a list of records of the form
        ``[h_box, o_box, o_cls, 0, h_score, o_score, pH, pO, pSp, pVerbs, pSpHO]``,
        one record per row of the fetched ``h_box``.
    """
    import gc

    def _head(key):
        # A missing prediction head is replaced by h_box, so the fetch list and
        # the record layout below always have the same length.
        return net.predictions.get(key, h_box)

    # The prediction dict does not change while looping, so build the fetches once.
    fetches = [
        _head('cls_prob_H'),
        _head('cls_prob_O'),
        _head('cls_prob_sp'),
        _head('cls_prob_hoi'),
        _head('cls_prob_spverbs'),
        h_box, o_box, o_cls, h_score, o_score, image_id,
    ]

    timer = Timer()
    detection = {}
    count = 0
    while True:
        timer.tic()
        try:
            (pH, pO, pSp, pVerbs, pSpHO,
             _h_box, _o_box, _o_cls, _h_score, _o_score, _image_id) = sess.run(fetches)
        except tf.errors.OutOfRangeError:
            print('END')
            break

        records = [[_h_box[i], _o_box[i], _o_cls[i], 0, _h_score[i], _o_score[i],
                    pH[i], pO[i], pSp[i], pVerbs[i], pSpHO[i]]
                   for i in range(len(_h_box))]
        # The same image id can appear in several consecutive batches; merge them.
        detection.setdefault(_image_id, []).extend(records)

        timer.toc()
        count += 1
        print('\rmodel: {} im_detect: {:d}/{:d}  {:d}, {:.3f}s'.format(
            net.model_name, count, total_images, _image_id, timer.average_time),
            end='', flush=True)

    # Use a context manager so the output file is flushed and closed.
    with open(output_dir, "wb") as out_file:
        pickle.dump(detection, out_file)
    del detection
    gc.collect()
Example #2
0
 def reset_classes(self):
     """Rebuild the HOI->verb/object conversion tensors and the derived GT labels.

     Recomputes ``self.obj_to_HO_matrix`` and ``self.verb_to_HO_matrix`` from the
     current ``self.verb_num_classes`` / ``self.obj_num_classes``. It then
     re-derives ``self.gt_obj_class`` and ``self.gt_verb_class`` from
     ``self.gt_class_HO``.
     """
     from ult.tools import get_convert_matrix
     verb_table, obj_table = get_convert_matrix(self.verb_num_classes, self.obj_num_classes)
     self.obj_to_HO_matrix = tf.constant(obj_table, tf.float32)
     self.verb_to_HO_matrix = tf.constant(verb_table, tf.float32)

     def _any_positive(to_ho):
         # Project the HOI labels onto the coarser classes. A class is marked 1.0
         # when its projected sum is > 0, i.e. (for non-negative labels) at
         # least one HOI label mapped to it is set.
         projected = tf.matmul(self.gt_class_HO, to_ho, transpose_b=True)
         return tf.cast(projected > 0, tf.float32)

     self.gt_obj_class = _any_positive(self.obj_to_HO_matrix)
     self.gt_verb_class = _any_positive(self.verb_to_HO_matrix)
                    # print(preds, element[4], element[5])
                    temp.append(preds[begin - 1 + i] * element[4] * element[5])
                    total.append(temp)
                    score.append(preds[begin - 1 + i] * element[4] *
                                 element[5])

        idx = np.argsort(score, axis=0)[::-1]
        for i_idx in range(min(len(idx), 19999)):
            all_boxes.append(total[idx[i_idx]])
    savefile = HICO_dir + 'detections_' + str(classid).zfill(2) + '.mat'
    # print('length:', classid, len(all_boxes))
    sio.savemat(savefile, {'all_boxes': all_boxes})
    return all_boxes


verb_to_HO_matrix, obj_to_HO_matrix = get_convert_matrix()


def _build_hoi_2_obj(obj_to_hoi, num_hoi=600, num_obj=80):
    """Return a dict mapping each HOI category index to an object class index.

    An entry ``obj_to_hoi[obj][hoi] > 0`` is read as "HOI ``hoi`` belongs to
    object ``obj``".
    - HOI indices with no positive entry are left out of the dict.
    - If several objects are positive for one HOI, the highest object index
      wins.
    - Keys are inserted in increasing HOI order.

    Args:
        obj_to_hoi: 2-D indexable table, rows indexed by object and columns by HOI.
        num_hoi: number of HOI columns to scan.
        num_obj: number of object rows to scan.
    """
    mapping = {}
    for hoi in range(num_hoi):
        for obj in range(num_obj):
            if obj_to_hoi[obj][hoi] > 0:
                mapping[hoi] = obj
    return mapping


# Built inside a function so the loop indices do not leak into module globals.
hoi_2_obj = _build_hoi_2_obj(obj_to_HO_matrix)


def obtain_fuse_preds(element, fuse_type):
    preds = element[3]
    if fuse_type != 'preds':
        pH = element[6]
        pO = element[7]
        pSp = element[8]
        pHoi = element[9]
    if fuse_type == 'preds':
Example #4
0
    def __init__(self, model_name):
        """Create input placeholders, fixed hyper-parameters and HOI label tables.

        Args:
            model_name: name of the model variant. Besides being stored, it is
                searched for the substrings 'unique_weights', '_pa3' and '_pa4'.
                Any of them causes an extra ResNet block ('block6') to be appended.

        Raises:
            Exception: if running on TensorFlow 1.1.0.

        Side effects:
            Loads ``cfg.ROOT_DIR + '/Data/num_inst.npy'`` from disk.
            Calls ``get_convert_matrix`` to build the HOI -> verb/object matrices.
        """
        self.model_name = model_name
        # Containers only initialised empty here; presumably populated later
        # when the graph is built.
        self.visualize = {}
        self.test_visualize = {}
        self.intermediate = {}
        self.predictions = {}
        self.score_summaries = {}
        self.event_summaries = {}
        self.train_summaries = []
        self.losses = {}

        # Input placeholders. The image batch dimension is fixed to 1.
        self.image = tf.placeholder(tf.float32,
                                    shape=[1, None, None, 3],
                                    name='image')
        # One 64x64x3 map per sample; presumably the human-object spatial
        # configuration ('sp'). TODO confirm against the data loader.
        self.spatial = tf.placeholder(tf.float32,
                                      shape=[None, 64, 64, 3],
                                      name='sp')
        # Box placeholders with 5 columns each; presumably
        # (batch_index, x1, y1, x2, y2). TODO confirm.
        self.H_boxes = tf.placeholder(tf.float32,
                                      shape=[None, 5],
                                      name='H_boxes')
        self.O_boxes = tf.placeholder(tf.float32,
                                      shape=[None, 5],
                                      name='O_boxes')
        # Labels over the 600 HOI categories (presumably multi-hot).
        self.gt_class_HO = tf.placeholder(tf.float32,
                                          shape=[None, 600],
                                          name='gt_class_HO')
        self.H_num = tf.placeholder(tf.int32)  # positive nums
        self.image_id = tf.placeholder(tf.int32)
        # Class counts: 600 HOI categories, 117 verbs, 80 objects.
        self.num_classes = 600
        self.compose_num_classes = 600
        self.num_fc = 1024
        self.verb_num_classes = 117
        self.obj_num_classes = 80
        self.scope = 'resnet_v1_101'
        self.stride = [
            16,
        ]
        # Learning rate, fed at run time.
        self.lr = tf.placeholder(tf.float32)
        # TF 1.1.0 is explicitly rejected. Other versions import
        # resnet_v1_block from tf.contrib.slim.
        if tf.__version__ == '1.1.0':
            raise Exception('wrong tensorflow version 1.1.0')
        else:
            from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import resnet_v1_block
            # ResNet-101 stage layout (3, 4, 23, 3 units) plus an extra
            # 'block5' that repeats block4's configuration.
            self.blocks = [
                resnet_v1_block('block1', base_depth=64, num_units=3,
                                stride=2),
                resnet_v1_block('block2',
                                base_depth=128,
                                num_units=4,
                                stride=2),
                resnet_v1_block('block3',
                                base_depth=256,
                                num_units=23,
                                stride=1),
                resnet_v1_block('block4',
                                base_depth=512,
                                num_units=3,
                                stride=1),
                resnet_v1_block('block5',
                                base_depth=512,
                                num_units=3,
                                stride=1)
            ]
            # Some model variants get a further 'block6' with the same config.
            # NOTE(review): `'x' in self.model_name` is the idiomatic spelling
            # of `self.model_name.__contains__('x')`.
            if self.model_name.__contains__('unique_weights') or self.model_name.__contains__('_pa3')\
                    or self.model_name.__contains__('_pa4'):
                print("add block6 unique_weights2")
                self.blocks.append(
                    resnet_v1_block('block6',
                                    base_depth=512,
                                    num_units=3,
                                    stride=1))
        # HO_weight: (1, 600) float32 array of per-HOI-category weights
        # (see the string below for how they were derived).
        """We copy from TIN. calculated by log(1/(n_c/sum(n_c)) c is the category and n_c is 
        the number of positive samples"""
        self.HO_weight = np.array([
            9.192927, 9.778443, 10.338059, 9.164914, 9.075144, 10.045923,
            8.714437, 8.59822, 12.977117, 6.2745423, 11.227917, 6.765012,
            9.436157, 9.56762, 11.0675745, 11.530198, 9.609821, 9.897503,
            6.664475, 6.811699, 6.644726, 9.170454, 13.670264, 3.903943,
            10.556748, 8.814335, 9.519224, 12.753973, 11.590822, 8.278912,
            5.5245695, 9.7286825, 8.997436, 10.699849, 9.601237, 11.965516,
            9.192927, 10.220277, 6.056692, 7.734048, 8.42324, 6.586457,
            6.969533, 10.579222, 13.670264, 4.4531965, 9.326459, 9.288238,
            8.071842, 10.431585, 12.417501, 11.530198, 11.227917, 4.0678477,
            8.854023, 12.571651, 8.225684, 10.996116, 11.0675745, 10.100731,
            7.0376034, 7.463688, 12.571651, 14.363411, 5.4902234, 11.0675745,
            14.363411, 8.45805, 10.269067, 9.820116, 14.363411, 11.272368,
            11.105314, 7.981595, 9.198626, 3.3284247, 14.363411, 12.977117,
            9.300817, 10.032678, 12.571651, 10.114916, 10.471591, 13.264799,
            14.363411, 8.01953, 10.412168, 9.644913, 9.981384, 7.2197933,
            14.363411, 3.1178555, 11.031207, 8.934066, 7.546675, 6.386472,
            12.060826, 8.862153, 9.799063, 12.753973, 12.753973, 10.412168,
            10.8976755, 10.471591, 12.571651, 9.519224, 6.207762, 12.753973,
            6.60636, 6.2896967, 4.5198326, 9.7887, 13.670264, 11.878505,
            11.965516, 8.576513, 11.105314, 9.192927, 11.47304, 11.367679,
            9.275815, 11.367679, 9.944571, 11.590822, 10.451388, 9.511381,
            11.144535, 13.264799, 5.888291, 11.227917, 10.779892, 7.643191,
            11.105314, 9.414651, 11.965516, 14.363411, 12.28397, 9.909063,
            8.94731, 7.0330057, 8.129001, 7.2817025, 9.874775, 9.758241,
            11.105314, 5.0690055, 7.4768796, 10.129305, 9.54313, 13.264799,
            9.699972, 11.878505, 8.260853, 7.1437693, 6.9321113, 6.990665,
            8.8104515, 11.655361, 13.264799, 4.515912, 9.897503, 11.418972,
            8.113436, 8.795067, 10.236277, 12.753973, 14.363411, 9.352776,
            12.417501, 0.6271591, 12.060826, 12.060826, 12.166186, 5.2946343,
            11.318889, 9.8308115, 8.016022, 9.198626, 10.8976755, 13.670264,
            11.105314, 14.363411, 9.653881, 9.503599, 12.753973, 5.80546,
            9.653881, 9.592727, 12.977117, 13.670264, 7.995224, 8.639826,
            12.28397, 6.586876, 10.929424, 13.264799, 8.94731, 6.1026597,
            12.417501, 11.47304, 10.451388, 8.95624, 10.996116, 11.144535,
            11.031207, 13.670264, 13.670264, 6.397866, 7.513285, 9.981384,
            11.367679, 11.590822, 7.4348736, 4.415428, 12.166186, 8.573451,
            12.977117, 9.609821, 8.601359, 9.055143, 11.965516, 11.105314,
            13.264799, 5.8201604, 10.451388, 9.944571, 7.7855496, 14.363411,
            8.5463, 13.670264, 7.9288645, 5.7561946, 9.075144, 9.0701065,
            5.6871653, 11.318889, 10.252538, 9.758241, 9.407584, 13.670264,
            8.570397, 9.326459, 7.488179, 11.798462, 9.897503, 6.7530537,
            4.7828183, 9.519224, 7.6492405, 8.031909, 7.8180614, 4.451856,
            10.045923, 10.83705, 13.264799, 13.670264, 4.5245686, 14.363411,
            10.556748, 10.556748, 14.363411, 13.670264, 14.363411, 8.037262,
            8.59197, 9.738439, 8.652985, 10.045923, 9.400566, 10.9622135,
            11.965516, 10.032678, 5.9017305, 9.738439, 12.977117, 11.105314,
            10.725825, 9.080208, 11.272368, 14.363411, 14.363411, 13.264799,
            6.9279733, 9.153925, 8.075553, 9.126969, 14.363411, 8.903826,
            9.488214, 5.4571533, 10.129305, 10.579222, 12.571651, 11.965516,
            6.237189, 9.428937, 9.618479, 8.620408, 11.590822, 11.655361,
            9.968962, 10.8080635, 10.431585, 14.363411, 3.796231, 12.060826,
            10.302968, 9.551227, 8.75394, 10.579222, 9.944571, 14.363411,
            6.272396, 10.625742, 9.690582, 13.670264, 11.798462, 13.670264,
            11.724354, 9.993963, 8.230013, 9.100721, 10.374427, 7.865129,
            6.514087, 14.363411, 11.031207, 11.655361, 12.166186, 7.419324,
            9.421769, 9.653881, 10.996116, 12.571651, 13.670264, 5.912144,
            9.7887, 8.585759, 8.272101, 11.530198, 8.886948, 5.9870906,
            9.269661, 11.878505, 11.227917, 13.670264, 8.339964, 7.6763024,
            10.471591, 10.451388, 13.670264, 11.185357, 10.032678, 9.313555,
            12.571651, 3.993144, 9.379805, 9.609821, 14.363411, 9.709451,
            8.965248, 10.451388, 7.0609145, 10.579222, 13.264799, 10.49221,
            8.978916, 7.124196, 10.602211, 8.9743395, 7.77862, 8.073695,
            9.644913, 9.339531, 8.272101, 4.794418, 9.016304, 8.012526,
            10.674532, 14.363411, 7.995224, 12.753973, 5.5157638, 8.934066,
            10.779892, 7.930471, 11.724354, 8.85808, 5.9025764, 14.363411,
            12.753973, 12.417501, 8.59197, 10.513264, 10.338059, 14.363411,
            7.7079706, 14.363411, 13.264799, 13.264799, 10.752493, 14.363411,
            14.363411, 13.264799, 12.417501, 13.670264, 6.5661197, 12.977117,
            11.798462, 9.968962, 12.753973, 11.47304, 11.227917, 7.6763024,
            10.779892, 11.185357, 14.363411, 7.369478, 14.363411, 9.944571,
            10.779892, 10.471591, 9.54313, 9.148476, 10.285873, 10.412168,
            12.753973, 14.363411, 6.0308623, 13.670264, 10.725825, 12.977117,
            11.272368, 7.663911, 9.137665, 10.236277, 13.264799, 6.715625,
            10.9622135, 14.363411, 13.264799, 9.575919, 9.080208, 11.878505,
            7.1863923, 9.366199, 8.854023, 9.874775, 8.2857685, 13.670264,
            11.878505, 12.166186, 7.616999, 9.44343, 8.288065, 8.8104515,
            8.347254, 7.4738197, 10.302968, 6.936267, 11.272368, 7.058223,
            5.0138307, 12.753973, 10.173757, 9.863602, 11.318889, 9.54313,
            10.996116, 12.753973, 7.8339925, 7.569945, 7.4427395, 5.560738,
            12.753973, 10.725825, 10.252538, 9.307165, 8.491293, 7.9161053,
            7.8849015, 7.782772, 6.3088884, 8.866243, 9.8308115, 14.363411,
            10.8976755, 5.908519, 10.269067, 9.176025, 9.852551, 9.488214,
            8.90809, 8.537411, 9.653881, 8.662968, 11.965516, 10.143904,
            14.363411, 14.363411, 9.407584, 5.281472, 11.272368, 12.060826,
            14.363411, 7.4135547, 8.920994, 9.618479, 8.891141, 14.363411,
            12.060826, 11.965516, 10.9622135, 10.9622135, 14.363411, 5.658909,
            8.934066, 12.571651, 8.614018, 11.655361, 13.264799, 10.996116,
            13.670264, 8.965248, 9.326459, 11.144535, 14.363411, 6.0517673,
            10.513264, 8.7430105, 10.338059, 13.264799, 6.878481, 9.065094,
            8.87035, 14.363411, 9.92076, 6.5872955, 10.32036, 14.363411,
            9.944571, 11.798462, 10.9622135, 11.031207, 7.652888, 4.334878,
            13.670264, 13.670264, 14.363411, 10.725825, 12.417501, 14.363411,
            13.264799, 11.655361, 10.338059, 13.264799, 12.753973, 8.206432,
            8.916674, 8.59509, 14.363411, 7.376845, 11.798462, 11.530198,
            11.318889, 11.185357, 5.0664344, 11.185357, 9.372978, 10.471591,
            9.6629305, 11.367679, 8.73579, 9.080208, 11.724354, 5.04781,
            7.3777695, 7.065643, 12.571651, 11.724354, 12.166186, 12.166186,
            7.215852, 4.374113, 11.655361, 11.530198, 14.363411, 6.4993753,
            11.031207, 8.344818, 10.513264, 10.032678, 14.363411, 14.363411,
            4.5873594, 12.28397, 13.670264, 12.977117, 10.032678, 9.609821
        ],
                                  dtype='float32').reshape(1, 600)
        # Array loaded from Data/num_inst.npy (contents not visible here;
        # presumably per-HOI-class instance counts).
        num_inst_path = cfg.ROOT_DIR + '/Data/num_inst.npy'
        num_inst = np.load(num_inst_path)
        self.num_inst = num_inst

        # Conversion matrices between HOI categories and verb/object classes.
        verb_to_HO_matrix, obj_to_HO_matrix = get_convert_matrix(
            self.verb_num_classes, self.obj_num_classes)

        self.obj_to_HO_matrix = tf.constant(obj_to_HO_matrix, tf.float32)
        self.verb_to_HO_matrix = tf.constant(verb_to_HO_matrix, tf.float32)
        # Object/verb ground truth is derived from the HOI labels. A class is
        # 1.0 when the projection of gt_class_HO onto it is positive, i.e.
        # (for non-negative labels) at least one associated HOI label is set.
        self.gt_obj_class = tf.cast(
            tf.matmul(
                self.gt_class_HO, self.obj_to_HO_matrix, transpose_b=True) > 0,
            tf.float32)
        self.gt_verb_class = tf.cast(
            tf.matmul(self.gt_class_HO,
                      self.verb_to_HO_matrix,
                      transpose_b=True) > 0, tf.float32)