def test_net_data_api1(sess, net, output_dir, h_box, o_box, o_cls, h_score, o_score, image_id,
                       total_images=10566):
    """Run the network over a tf.data input pipeline and pickle per-image detections.

    Each step fetches five prediction heads of ``net`` together with the
    dataset tensors, in a single ``sess.run`` call, until the pipeline raises
    ``tf.errors.OutOfRangeError``. Rows are grouped by image id and the
    resulting dict is pickled to ``output_dir``.

    Args:
        sess: TF1 session used to evaluate the fetches.
        net: network object exposing ``predictions`` (dict name -> tensor)
            and ``model_name`` (used only in the progress line).
        output_dir: path of the pickle file to write. Despite its name, this
            is opened directly as a file.
        h_box, o_box, o_cls, h_score, o_score, image_id: dataset tensors
            fetched on every step.
        total_images: denominator shown in the progress line only (default
            10566, the value previously hard-coded).

    Each stored row is ``[h_box, o_box, o_cls, 0, h_score, o_score,
    cls_prob_H, cls_prob_O, cls_prob_sp, cls_prob_hoi, cls_prob_spverbs]``.
    A head missing from ``net.predictions`` is replaced by the ``h_box``
    values. Errors from ``sess.run`` other than ``OutOfRangeError`` propagate.
    """
    import gc

    def _head(name):
        # Heads the network does not define fall back to ``h_box`` so the fetch
        # list, and therefore the row layout below, keeps a fixed structure.
        return net.predictions[name] if name in net.predictions else h_box

    # Build the fetch list once. No new ops are created during evaluation, so
    # this also works when the graph has been finalized.
    fetches = [
        _head('cls_prob_H'),
        _head('cls_prob_O'),
        _head('cls_prob_sp'),
        _head('cls_prob_hoi'),
        _head('cls_prob_spverbs'),
        h_box, o_box, o_cls, h_score, o_score, image_id,
    ]

    detection = {}
    timer = Timer()
    count = 0
    while True:
        timer.tic()
        try:
            (pH, pO, pSp, pVerbs, pSpHO,
             _h_box, _o_box, _o_cls, _h_score, _o_score, _image_id) = sess.run(fetches)
        except tf.errors.OutOfRangeError:
            # The input pipeline is exhausted.
            print('END')
            break

        rows = [[_h_box[i], _o_box[i], _o_cls[i], 0, _h_score[i], _o_score[i],
                 pH[i], pO[i], pSp[i], pVerbs[i], pSpHO[i]]
                for i in range(len(_h_box))]
        # The same image id may be produced by several steps; accumulate its rows.
        detection.setdefault(_image_id, []).extend(rows)

        timer.toc()
        count += 1
        print('\rmodel: {} im_detect: {:d}/{:d} {:d}, {:.3f}s'.format(
            net.model_name, count, total_images, _image_id, timer.average_time),
            end='', flush=True)

    # Use a context manager so the file is flushed and closed even if pickling fails.
    with open(output_dir, "wb") as f:
        pickle.dump(detection, f)
    del detection
    gc.collect()
def reset_classes(self):
    """Rebuild the verb/object -> HOI mapping constants and the derived GT targets.

    ``get_convert_matrix`` is called with ``self.verb_num_classes`` and
    ``self.obj_num_classes``. Its two outputs are stored as float32 constants
    ``self.verb_to_HO_matrix`` and ``self.obj_to_HO_matrix``.
    ``self.gt_obj_class`` and ``self.gt_verb_class`` are then recomputed as
    float32 indicators: ``gt_class_HO @ matrix^T > 0``. This presumably marks
    every object/verb that appears in at least one active HOI label; that
    reading assumes 0/1 matrices and non-negative labels (TODO confirm).
    """
    from ult.tools import get_convert_matrix

    verb_map, obj_map = get_convert_matrix(self.verb_num_classes, self.obj_num_classes)
    self.obj_to_HO_matrix = tf.constant(obj_map, tf.float32)
    self.verb_to_HO_matrix = tf.constant(verb_map, tf.float32)

    def _indicator(mapping):
        # Project the HOI labels through the transposed mapping; any positive hit -> 1.0.
        projected = tf.matmul(self.gt_class_HO, mapping, transpose_b=True)
        return tf.cast(projected > 0, tf.float32)

    self.gt_obj_class = _indicator(self.obj_to_HO_matrix)
    self.gt_verb_class = _indicator(self.verb_to_HO_matrix)
# print(preds, element[4], element[5]) temp.append(preds[begin - 1 + i] * element[4] * element[5]) total.append(temp) score.append(preds[begin - 1 + i] * element[4] * element[5]) idx = np.argsort(score, axis=0)[::-1] for i_idx in range(min(len(idx), 19999)): all_boxes.append(total[idx[i_idx]]) savefile = HICO_dir + 'detections_' + str(classid).zfill(2) + '.mat' # print('length:', classid, len(all_boxes)) sio.savemat(savefile, {'all_boxes': all_boxes}) return all_boxes verb_to_HO_matrix, obj_to_HO_matrix = get_convert_matrix() hoi_2_obj = {} for i in range(600): for j in range(80): if obj_to_HO_matrix[j][i] > 0: hoi_2_obj[i] = j def obtain_fuse_preds(element, fuse_type): preds = element[3] if fuse_type != 'preds': pH = element[6] pO = element[7] pSp = element[8] pHoi = element[9] if fuse_type == 'preds':
def __init__(self, model_name):
    """Create the graph placeholders, backbone block specs and class-weight constants.

    Args:
        model_name: model-variant name. It is stored on the instance. If it
            contains 'unique_weights', '_pa3' or '_pa4', an extra 'block6'
            spec is appended to ``self.blocks``.

    Raises:
        Exception: if ``tf.__version__`` is exactly '1.1.0'.
    """
    self.model_name = model_name
    # Containers that are filled in later, when the graph is built.
    self.visualize = {}
    self.test_visualize = {}
    self.intermediate = {}
    self.predictions = {}
    self.score_summaries = {}
    self.event_summaries = {}
    self.train_summaries = []
    self.losses = {}

    # Graph inputs (TF1 placeholders).
    self.image = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='image')
    self.spatial = tf.placeholder(tf.float32, shape=[None, 64, 64, 3], name='sp')
    self.H_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='H_boxes')
    self.O_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='O_boxes')
    self.gt_class_HO = tf.placeholder(tf.float32, shape=[None, 600], name='gt_class_HO')
    self.H_num = tf.placeholder(tf.int32)  # positive nums
    self.image_id = tf.placeholder(tf.int32)

    self.num_classes = 600
    self.compose_num_classes = 600
    self.num_fc = 1024
    self.verb_num_classes = 117
    self.obj_num_classes = 80
    self.scope = 'resnet_v1_101'
    self.stride = [16, ]
    self.lr = tf.placeholder(tf.float32)

    if tf.__version__ == '1.1.0':
        raise Exception('wrong tensorflow version 1.1.0')
    from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import resnet_v1_block
    self.blocks = [
        resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
        resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
        resnet_v1_block('block3', base_depth=256, num_units=23, stride=1),
        resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
        resnet_v1_block('block5', base_depth=512, num_units=3, stride=1)
    ]
    if any(tag in self.model_name for tag in ('unique_weights', '_pa3', '_pa4')):
        print("add block6 unique_weights2")
        self.blocks.append(
            resnet_v1_block('block6', base_depth=512, num_units=3, stride=1))

    # Per-HOI-class weights, copied from TIN and reshaped to (1, 600). Each value
    # is log(1 / (n_c / sum(n_c))), where c is the category and n_c is its
    # number of positive samples.
    self.HO_weight = np.array([
        9.192927, 9.778443, 10.338059, 9.164914, 9.075144, 10.045923, 8.714437, 8.59822,
        12.977117, 6.2745423, 11.227917, 6.765012, 9.436157, 9.56762, 11.0675745, 11.530198,
        9.609821, 9.897503, 6.664475, 6.811699, 6.644726, 9.170454, 13.670264, 3.903943,
        10.556748, 8.814335, 9.519224, 12.753973, 11.590822, 8.278912, 5.5245695, 9.7286825,
        8.997436, 10.699849, 9.601237, 11.965516, 9.192927, 10.220277, 6.056692, 7.734048,
        8.42324, 6.586457, 6.969533, 10.579222, 13.670264, 4.4531965, 9.326459, 9.288238,
        8.071842, 10.431585, 12.417501, 11.530198, 11.227917, 4.0678477, 8.854023, 12.571651,
        8.225684, 10.996116, 11.0675745, 10.100731, 7.0376034, 7.463688, 12.571651, 14.363411,
        5.4902234, 11.0675745, 14.363411, 8.45805, 10.269067, 9.820116, 14.363411, 11.272368,
        11.105314, 7.981595, 9.198626, 3.3284247, 14.363411, 12.977117, 9.300817, 10.032678,
        12.571651, 10.114916, 10.471591, 13.264799, 14.363411, 8.01953, 10.412168, 9.644913,
        9.981384, 7.2197933, 14.363411, 3.1178555, 11.031207, 8.934066, 7.546675, 6.386472,
        12.060826, 8.862153, 9.799063, 12.753973, 12.753973, 10.412168, 10.8976755, 10.471591,
        12.571651, 9.519224, 6.207762, 12.753973, 6.60636, 6.2896967, 4.5198326, 9.7887,
        13.670264, 11.878505, 11.965516, 8.576513, 11.105314, 9.192927, 11.47304, 11.367679,
        9.275815, 11.367679, 9.944571, 11.590822, 10.451388, 9.511381, 11.144535, 13.264799,
        5.888291, 11.227917, 10.779892, 7.643191, 11.105314, 9.414651, 11.965516, 14.363411,
        12.28397, 9.909063, 8.94731, 7.0330057, 8.129001, 7.2817025, 9.874775, 9.758241,
        11.105314, 5.0690055, 7.4768796, 10.129305, 9.54313, 13.264799, 9.699972, 11.878505,
        8.260853, 7.1437693, 6.9321113, 6.990665, 8.8104515, 11.655361, 13.264799, 4.515912,
        9.897503, 11.418972, 8.113436, 8.795067, 10.236277, 12.753973, 14.363411, 9.352776,
        12.417501, 0.6271591, 12.060826, 12.060826, 12.166186, 5.2946343, 11.318889, 9.8308115,
        8.016022,
        9.198626, 10.8976755, 13.670264, 11.105314, 14.363411, 9.653881, 9.503599, 12.753973,
        5.80546, 9.653881, 9.592727, 12.977117, 13.670264, 7.995224, 8.639826, 12.28397,
        6.586876, 10.929424, 13.264799, 8.94731, 6.1026597, 12.417501, 11.47304, 10.451388,
        8.95624, 10.996116, 11.144535, 11.031207, 13.670264, 13.670264, 6.397866, 7.513285,
        9.981384, 11.367679, 11.590822, 7.4348736, 4.415428, 12.166186, 8.573451, 12.977117,
        9.609821, 8.601359, 9.055143, 11.965516, 11.105314, 13.264799, 5.8201604, 10.451388,
        9.944571, 7.7855496, 14.363411, 8.5463, 13.670264, 7.9288645, 5.7561946, 9.075144,
        9.0701065, 5.6871653, 11.318889, 10.252538, 9.758241, 9.407584, 13.670264, 8.570397,
        9.326459, 7.488179, 11.798462, 9.897503, 6.7530537, 4.7828183, 9.519224, 7.6492405,
        8.031909, 7.8180614, 4.451856, 10.045923, 10.83705, 13.264799, 13.670264, 4.5245686,
        14.363411, 10.556748, 10.556748, 14.363411, 13.670264, 14.363411, 8.037262, 8.59197,
        9.738439, 8.652985, 10.045923, 9.400566, 10.9622135, 11.965516, 10.032678, 5.9017305,
        9.738439, 12.977117, 11.105314, 10.725825, 9.080208, 11.272368, 14.363411, 14.363411,
        13.264799, 6.9279733, 9.153925, 8.075553, 9.126969, 14.363411, 8.903826, 9.488214,
        5.4571533, 10.129305, 10.579222, 12.571651, 11.965516, 6.237189, 9.428937, 9.618479,
        8.620408, 11.590822, 11.655361, 9.968962, 10.8080635, 10.431585, 14.363411, 3.796231,
        12.060826, 10.302968, 9.551227, 8.75394, 10.579222, 9.944571, 14.363411, 6.272396,
        10.625742, 9.690582, 13.670264, 11.798462, 13.670264, 11.724354, 9.993963, 8.230013,
        9.100721, 10.374427, 7.865129, 6.514087, 14.363411, 11.031207, 11.655361, 12.166186,
        7.419324, 9.421769, 9.653881, 10.996116, 12.571651, 13.670264, 5.912144, 9.7887,
        8.585759, 8.272101, 11.530198, 8.886948, 5.9870906, 9.269661, 11.878505, 11.227917,
        13.670264, 8.339964, 7.6763024, 10.471591, 10.451388, 13.670264, 11.185357, 10.032678,
        9.313555, 12.571651, 3.993144, 9.379805, 9.609821, 14.363411, 9.709451, 8.965248,
        10.451388, 7.0609145, 10.579222, 13.264799, 10.49221,
        8.978916, 7.124196, 10.602211, 8.9743395, 7.77862, 8.073695, 9.644913, 9.339531,
        8.272101, 4.794418, 9.016304, 8.012526, 10.674532, 14.363411, 7.995224, 12.753973,
        5.5157638, 8.934066, 10.779892, 7.930471, 11.724354, 8.85808, 5.9025764, 14.363411,
        12.753973, 12.417501, 8.59197, 10.513264, 10.338059, 14.363411, 7.7079706, 14.363411,
        13.264799, 13.264799, 10.752493, 14.363411, 14.363411, 13.264799, 12.417501, 13.670264,
        6.5661197, 12.977117, 11.798462, 9.968962, 12.753973, 11.47304, 11.227917, 7.6763024,
        10.779892, 11.185357, 14.363411, 7.369478, 14.363411, 9.944571, 10.779892, 10.471591,
        9.54313, 9.148476, 10.285873, 10.412168, 12.753973, 14.363411, 6.0308623, 13.670264,
        10.725825, 12.977117, 11.272368, 7.663911, 9.137665, 10.236277, 13.264799, 6.715625,
        10.9622135, 14.363411, 13.264799, 9.575919, 9.080208, 11.878505, 7.1863923, 9.366199,
        8.854023, 9.874775, 8.2857685, 13.670264, 11.878505, 12.166186, 7.616999, 9.44343,
        8.288065, 8.8104515, 8.347254, 7.4738197, 10.302968, 6.936267, 11.272368, 7.058223,
        5.0138307, 12.753973, 10.173757, 9.863602, 11.318889, 9.54313, 10.996116, 12.753973,
        7.8339925, 7.569945, 7.4427395, 5.560738, 12.753973, 10.725825, 10.252538, 9.307165,
        8.491293, 7.9161053, 7.8849015, 7.782772, 6.3088884, 8.866243, 9.8308115, 14.363411,
        10.8976755, 5.908519, 10.269067, 9.176025, 9.852551, 9.488214, 8.90809, 8.537411,
        9.653881, 8.662968, 11.965516, 10.143904, 14.363411, 14.363411, 9.407584, 5.281472,
        11.272368, 12.060826, 14.363411, 7.4135547, 8.920994, 9.618479, 8.891141, 14.363411,
        12.060826, 11.965516, 10.9622135, 10.9622135, 14.363411, 5.658909, 8.934066, 12.571651,
        8.614018, 11.655361, 13.264799, 10.996116, 13.670264, 8.965248, 9.326459, 11.144535,
        14.363411, 6.0517673, 10.513264, 8.7430105, 10.338059, 13.264799, 6.878481, 9.065094,
        8.87035, 14.363411, 9.92076, 6.5872955, 10.32036, 14.363411, 9.944571, 11.798462,
        10.9622135, 11.031207, 7.652888, 4.334878, 13.670264, 13.670264, 14.363411, 10.725825,
        12.417501, 14.363411, 13.264799, 11.655361,
        10.338059, 13.264799, 12.753973, 8.206432, 8.916674, 8.59509, 14.363411, 7.376845,
        11.798462, 11.530198, 11.318889, 11.185357, 5.0664344, 11.185357, 9.372978, 10.471591,
        9.6629305, 11.367679, 8.73579, 9.080208, 11.724354, 5.04781, 7.3777695, 7.065643,
        12.571651, 11.724354, 12.166186, 12.166186, 7.215852, 4.374113, 11.655361, 11.530198,
        14.363411, 6.4993753, 11.031207, 8.344818, 10.513264, 10.032678, 14.363411, 14.363411,
        4.5873594, 12.28397, 13.670264, 12.977117, 10.032678, 9.609821
    ], dtype='float32').reshape(1, 600)

    # Array loaded from <cfg.ROOT_DIR>/Data/num_inst.npy. Presumably these are
    # per-class instance counts; verify against the data-preparation script.
    num_inst_path = cfg.ROOT_DIR + '/Data/num_inst.npy'
    num_inst = np.load(num_inst_path)
    self.num_inst = num_inst

    # Verb/object -> HOI mapping constants, and GT verb/object indicators derived
    # from gt_class_HO (1.0 where the projected label sum is > 0).
    verb_to_HO_matrix, obj_to_HO_matrix = get_convert_matrix(
        self.verb_num_classes, self.obj_num_classes)
    self.obj_to_HO_matrix = tf.constant(obj_to_HO_matrix, tf.float32)
    self.verb_to_HO_matrix = tf.constant(verb_to_HO_matrix, tf.float32)
    self.gt_obj_class = tf.cast(
        tf.matmul(self.gt_class_HO, self.obj_to_HO_matrix, transpose_b=True) > 0,
        tf.float32)
    self.gt_verb_class = tf.cast(
        tf.matmul(self.gt_class_HO, self.verb_to_HO_matrix, transpose_b=True) > 0,
        tf.float32)