import os

import numpy as np
import tensorflow as tf

# BaseVQAModel, AttentionModel, AnswerTokenToTopAnswer, END_TOKEN,
# put_to_array and _parse_gt_questions are assumed to be imported from
# elsewhere in this repository.


class VanillaModel(BaseVQAModel):
    def __init__(self, ckpt_file='model/kprestval_VQA-BaseNorm/model.ckpt-26000'):
        BaseVQAModel.__init__(self)
        self.top_k = 2
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file

        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self._subset = 'test'
        self._year = 2015
        self.name = ' ------- DeeperLSTM ------- '

        with self.g.as_default():
            self.sess = tf.Session()
            self.model = BaseModel(config, phase='test')
            self.model.build()
            vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=vars)
            self.saver.restore(self.sess, ckpt_file)

    def _load_image(self, image_id):
        # Load the cached ResNet-152 res5c feature map (C, H, W) for this
        # image and global-average-pool it into a single 2048-d vector.
        FEAT_ROOT = '/usr/data/fl302/data/VQA/ResNet152/resnet_res5c'
        filename = '%s%d/COCO_%s%d_%012d.jpg' % (
            self._subset, self._year, self._subset, self._year, image_id)
        f = np.load(os.path.join(FEAT_ROOT, filename + '.npz'))['x']
        f1 = np.mean(np.mean(f.transpose((1, 2, 0)), axis=0), axis=0)
        return f1[np.newaxis, :]
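# The transpose-then-double-mean in `_load_image` above is plain global
# average pooling. A minimal sketch on a synthetic res5c-style feature map
# (the 14x14 grid size is an assumption for illustration, not taken from
# the repo):
def _demo_global_average_pooling():
    f = np.random.rand(2048, 14, 14).astype(np.float32)
    # (C, H, W) -> (H, W, C), then average over H and then W.
    f1 = np.mean(np.mean(f.transpose((1, 2, 0)), axis=0), axis=0)
    # Equivalent single-call pooling over both spatial axes.
    f2 = f.mean(axis=(1, 2))
    assert f1.shape == (2048,)
    assert np.allclose(f1, f2, atol=1e-5)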
class VQAWrapper(object):
    def __init__(self, g, sess):
        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        ckpt_file = 'model/kprestval_VQA-BaseNorm/model.ckpt-26000'
        # Reuse the caller's session; it must be created on graph `g`.
        self.sess = sess

        with g.as_default():
            self.model = BaseModel(config, phase='test')
            with tf.variable_scope('VQA'):
                self.model.build()
            # The checkpoint was saved without the 'VQA/' prefix, so map each
            # scoped variable back to its name in the checkpoint.
            vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='VQA')
            vars_dict = {v.name.replace('VQA/', '').split(':')[0]: v
                         for v in vars}
            self.saver = tf.train.Saver(var_list=vars_dict)
            self.saver.restore(self.sess, ckpt_file)

    def get_scores(self, sampled, image, top_ans_id):
        pathes = []
        for p in sampled:
            if p[-1] == END_TOKEN:
                pathes.append(p[1:-1])  # strip <START> and <END> tokens
            else:
                pathes.append(p[1:])  # strip <START>; no <END> was generated
        num_unk = len(sampled)
        # Pair every sampled question with a copy of the same image feature.
        images_aug = np.tile(image, [num_unk, 1])
        # Pad the variable-length questions into a dense array plus lengths.
        arr, arr_len = put_to_array(pathes)
        scores = self.model.inference(self.sess, [images_aug, arr, arr_len])
        # Keep only each question's score for the target top answer.
        vqa_scores = scores[:, top_ans_id].flatten()
        return vqa_scores
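# `put_to_array` is a repo helper whose implementation is not shown here.
# A hypothetical stand-in, assuming it zero-pads each token sequence to the
# longest one and returns an int32 (batch, max_len) matrix plus lengths:
def put_to_array_sketch(pathes):
    arr_len = np.array([len(p) for p in pathes], dtype=np.int32)
    arr = np.zeros((len(pathes), arr_len.max()), dtype=np.int32)
    for i, p in enumerate(pathes):
        arr[i, :len(p)] = p
    return arr, arr_len
# e.g. put_to_array_sketch([[4, 9, 7], [12, 3]]) returns
#   [[ 4  9  7]
#    [12  3  0]]  and  [3 2]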
class VanillaModel(BaseVQAModel):
    """Variant of VanillaModel that looks up image features in a pre-built
    HDF5 cache instead of reading per-image .npz files."""

    def __init__(self, ckpt_file='model/kprestval_VQA-BaseNorm/model.ckpt-26000'):
        BaseVQAModel.__init__(self)
        self.top_k = 2
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file

        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self.name = ' ------- DeeperLSTM ------- '

        with self.g.as_default():
            self.sess = tf.Session()
            self.model = BaseModel(config, phase='test')
            self.model.build()
            vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=vars)
            self.saver.restore(self.sess, ckpt_file)
        self._init_image_cache()

    def _init_image_cache(self):
        from util import load_hdf5
        d = load_hdf5('data/res152_std_mscoco_kptest.data')
        # d = load_hdf5('data/res152_std_mscoco_kpval.data')  # kpval split
        image_ids = d['image_ids']
        self.im_feats = d['features']
        # Map each image id to its row index in the feature matrix.
        self.image_id2index = {image_id: idx
                               for idx, image_id in enumerate(image_ids)}

    def _load_image(self, image_id):
        idx = self.image_id2index[image_id]
        return self.im_feats[idx][np.newaxis, :]
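# `util.load_hdf5` is also external to this file. Assuming it reads every
# dataset in the file into an in-memory dict of numpy arrays, an h5py-based
# equivalent might look like this (hypothetical, not the repo's code):
def load_hdf5_sketch(fname):
    import h5py
    with h5py.File(fname, 'r') as hf:
        return {key: np.asarray(hf[key]) for key in hf.keys()}
# Loading the full feature matrix once and indexing rows by image id trades
# memory for skipping the per-image .npz reads used by the variant above.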
class VQARewards(object):
    def __init__(self, ckpt_file='', use_dis_reward=False,
                 use_attention_model=False):
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file
        self.use_attention_model = use_attention_model

        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self.ans2id = AnswerTokenToTopAnswer()
        self.use_dis_reward = use_dis_reward

        with self.g.as_default():
            self.sess = tf.Session()
            if self.use_attention_model:
                self.model = AttentionModel(config, phase='test')
            else:
                self.model = BaseModel(config, phase='test')
            self.model.build()
            vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=vars)
            self.saver.restore(self.sess, ckpt_file)

    def process_answers(self, ans, ans_len):
        # Convert padded answer token sequences to top-answer vocabulary ids.
        ans_pathes = _parse_gt_questions(ans, ans_len)
        return self.ans2id.get_top_answer(ans_pathes)

    def get_reward(self, sampled, inputs):
        if len(inputs) == 3:
            images, ans, ans_len = inputs
            top_ans_ids = self.process_answers(ans, ans_len)
        else:
            assert len(inputs) == 4
            images, ans, ans_len, top_ans_ids = inputs

        images_aug = []
        top_ans_ids_aug = []
        answer_aug = []
        answer_len_aug = []
        pathes = []
        # Flatten the per-image lists of sampled questions, replicating the
        # image feature and answer for every sample drawn from that image.
        for _idx, ps in enumerate(sampled):
            for p in ps:
                if p[-1] == END_TOKEN:
                    pathes.append(p[1:-1])  # strip <START> and <END> tokens
                else:
                    pathes.append(p[1:])  # strip <START>; no <END> was generated
                images_aug.append(images[_idx][np.newaxis, :])
                answer_aug.append(ans[_idx][np.newaxis, :])
                answer_len_aug.append(ans_len[_idx])
                top_ans_ids_aug.append(top_ans_ids[_idx])

        # Pad the variable-length questions into dense arrays.
        arr, arr_len = put_to_array(pathes)
        images_aug = np.concatenate(images_aug)
        answer_aug = np.concatenate(answer_aug).astype(np.int32)
        top_ans_ids_aug = np.array(top_ans_ids_aug)
        answer_len_aug = np.array(answer_len_aug, dtype=np.int32)

        # Run inference in VQA.
        scores = self.model.inference(self.sess, [images_aug, arr, arr_len])
        if self.use_dis_reward:
            # Discrete reward: 1 if the VQA model's top prediction matches
            # the ground-truth answer, else 0.
            vqa_scores = np.require(scores.argmax(axis=1) == top_ans_ids_aug,
                                    np.float32)
        else:
            # Continuous reward: the score assigned to the ground-truth answer.
            _this_batch_size = scores.shape[0]
            vqa_scores = scores[np.arange(_this_batch_size), top_ans_ids_aug]
        # Id 2000 flags answers outside the top-answer vocabulary
        # (assumed OOV marker, inferred from this check).
        is_valid = top_ans_ids_aug != 2000
        return vqa_scores, [images_aug, answer_aug, answer_len_aug, is_valid]
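# Both reward modes in `get_reward` reduce to simple indexing over the score
# matrix. A toy illustration (made-up values) with batch size 3 and a
# 5-answer vocabulary:
def _demo_reward_modes():
    scores = np.array([[0.1, 0.7, 0.1, 0.05, 0.05],
                       [0.6, 0.1, 0.1, 0.1, 0.1],
                       [0.2, 0.2, 0.5, 0.05, 0.05]], dtype=np.float32)
    top_ans_ids = np.array([1, 2, 2])
    # Discrete: 1.0 when the argmax prediction matches the ground truth.
    dis = np.require(scores.argmax(axis=1) == top_ans_ids, np.float32)
    # -> [1., 0., 1.]
    # Continuous: probability mass placed on the ground-truth answer.
    cont = scores[np.arange(scores.shape[0]), top_ans_ids]
    # -> [0.7, 0.1, 0.5]
    return dis, cont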