import numpy as np
import tensorflow as tf

# END_TOKEN, put_to_array, _parse_gt_questions, AnswerTokenToTopAnswer and
# AttentionModel are assumed to come from this repo's shared utilities.


class VQAWrapper(object):
    def __init__(self, g, sess):
        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        ckpt_file = 'model/kprestval_VQA-BaseNorm/model.ckpt-26000'
        with g.as_default():
            # A fresh session is created on the shared graph; the `sess`
            # argument is kept for interface compatibility but unused.
            self.sess = tf.Session()
            self.model = BaseModel(config, phase='test')
            with tf.variable_scope('VQA'):
                self.model.build()
            # Collect only the variables under the 'VQA' scope and strip the
            # scope prefix so the names match the un-scoped checkpoint.
            train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope='VQA')
            vars_dict = {v.name.replace('VQA/', '').split(':')[0]: v
                         for v in train_vars}
            self.saver = tf.train.Saver(var_list=vars_dict)
            self.saver.restore(self.sess, ckpt_file)

    def get_scores(self, sampled, image, top_ans_id):
        paths = []
        for p in sampled:
            if p[-1] == END_TOKEN:
                paths.append(p[1:-1])  # strip start and end tokens
            else:
                paths.append(p[1:])  # no end token sampled; strip start only
        num_samples = len(sampled)
        # replicate the image feature so every sampled question sees it
        images_aug = np.tile(image, [num_samples, 1])
        # pad the variable-length question paths into a dense array
        arr, arr_len = put_to_array(paths)
        scores = self.model.inference(self.sess, [images_aug, arr, arr_len])
        vqa_scores = scores[:, top_ans_id].flatten()
        return vqa_scores
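# Hedged usage sketch, not part of the original file: it shows the call
# pattern and input shapes VQAWrapper.get_scores expects. The start-token
# id, the 2048-d image feature, and the token ids are placeholder
# assumptions about the surrounding repo.
def _demo_vqa_wrapper(start_token=1):
    g = tf.Graph()
    wrapper = VQAWrapper(g, sess=None)  # the sess argument is unused
    image = np.zeros([1, 2048], dtype=np.float32)  # placeholder feature
    # two sampled questions as token-id paths: <start> ... [<end>]
    sampled = [[start_token, 12, 45, 7, END_TOKEN],
               [start_token, 12, 45]]
    # one score per question: P(answer == top_ans_id | image, question)
    return wrapper.get_scores(sampled, image, top_ans_id=3)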
class VQARewards(object):
    def __init__(self, ckpt_file='', use_dis_reward=False,
                 use_attention_model=False):
        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file
        self.use_attention_model = use_attention_model
        self.ans2id = AnswerTokenToTopAnswer()
        self.use_dis_reward = use_dis_reward
        with self.g.as_default():
            self.sess = tf.Session()
            if self.use_attention_model:
                self.model = AttentionModel(config, phase='test')
            else:
                self.model = BaseModel(config, phase='test')
            self.model.build()
            train_vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=train_vars)
            self.saver.restore(self.sess, ckpt_file)

    def process_answers(self, ans, ans_len):
        # decode the padded answer arrays back to token lists, then map each
        # answer to its index in the top-answer vocabulary
        ans_pathes = _parse_gt_questions(ans, ans_len)
        return self.ans2id.get_top_answer(ans_pathes)

    def get_reward(self, sampled, inputs):
        if len(inputs) == 3:
            images, ans, ans_len = inputs
            top_ans_ids = self.process_answers(ans, ans_len)
        else:
            assert len(inputs) == 4
            images, ans, ans_len, top_ans_ids = inputs
        images_aug = []
        top_ans_ids_aug = []
        answer_aug = []
        answer_len_aug = []
        paths = []
        # flatten the per-image lists of sampled questions, replicating the
        # image, answer, and label so each question keeps its own context
        for _idx, ps in enumerate(sampled):
            for p in ps:
                if p[-1] == END_TOKEN:
                    paths.append(p[1:-1])  # strip start and end tokens
                else:
                    paths.append(p[1:])  # no end token; strip start only
                images_aug.append(images[_idx][np.newaxis, :])
                answer_aug.append(ans[_idx][np.newaxis, :])
                answer_len_aug.append(ans_len[_idx])
                top_ans_ids_aug.append(top_ans_ids[_idx])
        # pad the question paths into a dense array
        arr, arr_len = put_to_array(paths)
        images_aug = np.concatenate(images_aug)
        answer_aug = np.concatenate(answer_aug).astype(np.int32)
        top_ans_ids_aug = np.array(top_ans_ids_aug)
        answer_len_aug = np.array(answer_len_aug, dtype=np.int32)
        # run inference in the VQA model
        scores = self.model.inference(self.sess, [images_aug, arr, arr_len])
        if self.use_dis_reward:
            # discrete reward: 1.0 iff the model's argmax answer matches the
            # ground-truth top answer
            vqa_scores = np.require(scores.argmax(axis=1) == top_ans_ids_aug,
                                    np.float32)
        else:
            # soft reward: probability assigned to the ground-truth answer
            _this_batch_size = scores.shape[0]
            vqa_scores = scores[np.arange(_this_batch_size), top_ans_ids_aug]
        # id 2000 marks answers outside the top-answer vocabulary
        is_valid = top_ans_ids_aug != 2000
        return vqa_scores, [images_aug, answer_aug, answer_len_aug, is_valid]
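# Hedged usage sketch, not part of the original file: it wires VQARewards
# into a sampling loop. The checkpoint path (reused from VQAWrapper), the
# feature width, answer padding, and token ids are placeholder assumptions;
# only the get_reward contract comes from the class above.
def _demo_vqa_rewards(start_token=1):
    rewarder = VQARewards(
        ckpt_file='model/kprestval_VQA-BaseNorm/model.ckpt-26000')
    batch = 2
    images = np.zeros([batch, 2048], dtype=np.float32)  # image features
    ans = np.zeros([batch, 8], dtype=np.int32)          # padded GT answers
    ans_len = np.array([3, 5], dtype=np.int32)          # true answer lengths
    # two sampled questions per image, bracketed by start/end tokens
    sampled = [[[start_token, 4, 9, END_TOKEN],
                [start_token, 4, 9, 11]] for _ in range(batch)]
    rewards, (im_aug, ans_aug, len_aug, is_valid) = rewarder.get_reward(
        sampled, [images, ans, ans_len])
    # is_valid masks questions whose GT answer is outside the top-answer set
    return rewards[is_valid]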