Example #1
0
    def __init__(self, opt, dataset):
        """Build the reward scorer selected by ``opt.reward_type``.

        Args:
            opt: options object. ``opt.reward_type`` picks the metric
                ('METEOR', 'CIDEr', 'Bleu_3', 'Bleu_4' or 'ROUGE_L'); for
                'CIDEr', ``opt.cached_tokens`` is forwarded to CiderD as ``df``.
            dataset: stored on ``self.dataset`` for later use.

        Raises:
            ValueError: if ``opt.reward_type`` is not one of the supported
                metrics (a subclass of Exception, so broader handlers still
                catch it).
        """
        super(ReinforceCriterion, self).__init__()
        self.dataset = dataset
        self.reward_type = opt.reward_type
        # Index into the list of scores returned by Bleu(4) (Bleu_1..Bleu_4);
        # stays None for metrics that return a single score.
        self.bleu = None

        if self.reward_type == 'METEOR':
            from vist_eval.meteor.meteor import Meteor
            self.reward_scorer = Meteor()
        elif self.reward_type == 'CIDEr':
            # CiderD is imported from the relative 'cider' directory; add it
            # only once so repeated constructions don't keep growing sys.path.
            if "cider" not in sys.path:
                sys.path.append("cider")
            from pyciderevalcap.ciderD.ciderD import CiderD
            self.reward_scorer = CiderD(df=opt.cached_tokens)
        elif self.reward_type in ('Bleu_4', 'Bleu_3'):
            from vist_eval.bleu.bleu import Bleu
            self.reward_scorer = Bleu(4)
            # 'Bleu_3' -> 2, 'Bleu_4' -> 3
            self.bleu = int(self.reward_type[-1]) - 1
        elif self.reward_type == 'ROUGE_L':
            from vist_eval.rouge.rouge import Rouge
            self.reward_scorer = Rouge()
        else:
            err_msg = "{} scorer hasn't been implemented".format(
                self.reward_type)
            logging.error(err_msg)
            raise ValueError(err_msg)
Example #2
0
    def evaluate(self, album_to_Gts, album_to_Res, measure=None):
        """Score generated album stories with the caption metrics.

        Args:
            album_to_Gts: mapping album id -> reference stories, passed as the
                ground truth to each scorer's ``compute_score`` (presumably a
                list of strings per album -- confirm against callers).
            album_to_Res: mapping album id -> generated story, passed as the
                result to each scorer's ``compute_score`` (presumably a
                one-element list of strings per album).
            measure: optional collection of metric names to run, any of
                'bleu', 'meteor', 'rouge', 'cider'. None or empty runs all
                four (the original behaviour).

        Corpus-level scores are recorded with ``self.setEval`` and per-album
        scores with ``self.setAlbumToEval``. ``self.setEvalAlbums`` is called
        once at the end.
        """
        self.album_to_Res = album_to_Res
        self.album_to_Gts = album_to_Gts

        # =================================================
        # Set up scorers
        # =================================================
        print('setting up scorers...')
        # (measure key, factory, result name(s)) in run order. The factories
        # defer construction so unselected scorers are never instantiated.
        scorer_specs = [
            ('bleu', lambda: Bleu(4),
             ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            ('meteor', lambda: Meteor(), "METEOR"),
            ('rouge', lambda: Rouge(), "ROUGE_L"),
            ('cider', lambda: Cider(), "CIDEr"),  # df='VIST/VIST-train-words'
        ]
        scorers = [(make(), method)
                   for key, make, method in scorer_specs
                   if not measure or key in measure]

        # =================================================
        # Compute scores
        # =================================================
        for scorer, method in scorers:
            print('computing %s score ...' % (scorer.method()))
            score, scores = scorer.compute_score(self.album_to_Gts,
                                                 self.album_to_Res)
            # Per-album scores are paired with the ground-truth album ids;
            # build that list once per scorer instead of once per metric.
            album_ids = list(self.album_to_Gts.keys())
            if isinstance(method, list):
                # Multi-valued scorer (Bleu): one entry per result name.
                for sc, scs, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    self.setAlbumToEval(scs, album_ids, m)
                    print('%s: %.3f' % (m, sc))
            else:
                self.setEval(score, method)
                self.setAlbumToEval(scores, album_ids, method)
                print('%s: %.3f' % (method, score))

        self.setEvalAlbums()
Example #3
0
class ReinforceCriterion(nn.Module):
    """REINFORCE loss with a learned baseline and a caption-metric reward.

    The reward for each sampled story is the ``opt.reward_type`` metric score
    of the decoded story against the dataset's reference story. The loss is
    the sum of two terms. The policy-gradient term is weighted by the
    detached advantage ``reward - baseline``. A squared-error term fits the
    baseline to the reward. Both terms are averaged over a token mask.
    """

    def __init__(self, opt, dataset):
        """Build the reward scorer selected by ``opt.reward_type``.

        Args:
            opt: options object. ``opt.reward_type`` picks the metric
                ('METEOR', 'CIDEr', 'Bleu_3', 'Bleu_4' or 'ROUGE_L'); for
                'CIDEr', ``opt.cached_tokens`` is forwarded to CiderD as ``df``.
            dataset: provides ``get_vocab``, ``get_id`` and ``get_GT``, which
                ``forward`` uses to decode samples and fetch references.

        Raises:
            ValueError: if ``opt.reward_type`` is not one of the supported
                metrics (a subclass of Exception, so broader handlers still
                catch it).
        """
        super(ReinforceCriterion, self).__init__()
        self.dataset = dataset
        self.reward_type = opt.reward_type
        # Index into the list of scores returned by Bleu(4) (Bleu_1..Bleu_4);
        # stays None for metrics that return a single score.
        self.bleu = None

        if self.reward_type == 'METEOR':
            from vist_eval.meteor.meteor import Meteor
            self.reward_scorer = Meteor()
        elif self.reward_type == 'CIDEr':
            # CiderD is imported from the relative 'cider' directory; add it
            # only once so repeated constructions don't keep growing sys.path.
            if "cider" not in sys.path:
                sys.path.append("cider")
            from pyciderevalcap.ciderD.ciderD import CiderD
            self.reward_scorer = CiderD(df=opt.cached_tokens)
        elif self.reward_type in ('Bleu_4', 'Bleu_3'):
            from vist_eval.bleu.bleu import Bleu
            self.reward_scorer = Bleu(4)
            # 'Bleu_3' -> 2, 'Bleu_4' -> 3
            self.bleu = int(self.reward_type[-1]) - 1
        elif self.reward_type == 'ROUGE_L':
            from vist_eval.rouge.rouge import Rouge
            self.reward_scorer = Rouge()
        else:
            err_msg = "{} scorer hasn't been implemented".format(
                self.reward_type)
            logging.error(err_msg)
            raise ValueError(err_msg)

    def _cal_action_loss(self, log_probs, reward, mask):
        """Return the masked mean of ``-log_probs * reward``."""
        output = -log_probs * reward * mask
        return torch.sum(output) / torch.sum(mask)

    def _cal_value_loss(self, reward, baseline, mask):
        """Return the masked mean squared error between reward and baseline."""
        output = (reward - baseline).pow(2) * mask
        return torch.sum(output) / torch.sum(mask)

    def forward(self, seq, seq_log_probs, baseline, index, rewards=None):
        '''
        :param seq: (batch_size, num_sents, seq_length) sampled token ids;
            the original documentation uses num_sents = 5
        :param seq_log_probs: (batch_size, num_sents, seq_length)
        :param baseline: (batch_size, num_sents, seq_length)
        :param index: (batch_size,) dataset indexes, used to look up each
            sample's id and reference story
        :param rewards: optional precomputed rewards. They are viewed as
            (-1, seq.size(1), 1), so they should hold
            seq.size(0) * seq.size(1) values (one per sentence). When None,
            one reward per sample is computed with the configured scorer.
        :return: (loss, avg_reward) -- action loss + value loss, and the
            mean of the rewards
        '''
        if rewards is None:
            # Score each decoded story against its reference story.
            sents = utils.decode_story(self.dataset.get_vocab(), seq)

            rewards = []
            batch_size = seq.size(0)
            for i, story in enumerate(sents):
                vid, _ = self.dataset.get_id(index[i])
                GT_story = self.dataset.get_GT(index[i])
                result = {vid: [story]}
                gt = {vid: [GT_story]}
                score, _ = self.reward_scorer.compute_score(gt, result)
                if self.bleu is not None:
                    # Bleu returns every n-gram order; keep the requested one.
                    rewards.append(score[self.bleu])
                else:
                    rewards.append(score)
            rewards = torch.FloatTensor(rewards)  # (batch_size,)
            avg_reward = rewards.mean()
            # Broadcast each sample's reward over all of its tokens. Move it
            # to the GPU only when seq is there, so CPU runs also work.
            rewards = Variable(rewards.view(batch_size, 1, 1).expand_as(seq))
            if seq.is_cuda:
                rewards = rewards.cuda()
        else:
            avg_reward = rewards.mean()
            # One reward per sentence, broadcast along the token axis.
            rewards = rewards.view(-1, seq.size(1), 1)

        # Position t is kept when token t-1 is > 0, and position 0 is always
        # kept (presumably 0 is the pad/end id, so the terminating token is
        # still scored -- TODO confirm).
        mask = (seq > 0).float()
        if mask.size(2) > 1:
            mask = torch.cat([
                mask.new(mask.size(0), mask.size(1), 1).fill_(1),
                mask[:, :, :-1]
            ], 2).contiguous()
        else:
            mask.fill_(1)
        mask = Variable(mask)

        # The advantage is detached (built from .data) so the policy term
        # sends no gradient into the baseline; only the value loss trains it.
        advantage = Variable(rewards.data - baseline.data)
        value_loss = self._cal_value_loss(rewards, baseline, mask)
        action_loss = self._cal_action_loss(seq_log_probs, advantage, mask)

        return action_loss + value_loss, avg_reward
Example #4
0
	def evaluate(self, measure=None):
		"""
		Score the predicted stories in self.preds against the VIST references.

		measure is a subset of ['bleu', 'meteor', 'rouge', 'cider']
		if measure is None (or empty), we will apply all the above.

		Predictions and references are keyed by the '_'-joined image ids of a
		story. Corpus-level scores are recorded with self.setEval, per-story
		scores with self.setStimgidsToEval, and self.setEvalStimgids is called
		once at the end. Non-ASCII characters are dropped from both sides.

		Raises KeyError if a prediction's image ids match no reference story.
		"""

		# story_img_ids -> [pred story str]
		stimgids_to_Res = {
			item['stimgids']: [item['pred_story_str'].encode('ascii', 'ignore').decode('ascii')]
			for item in self.preds
		}

		# story_img_ids -> all reference stories over that image sequence
		stimgids_to_stories = {}
		for story in self.vist_sis.stories:
			story_img_ids = '_'.join(str(img_id) for img_id in story['img_ids'])
			stimgids_to_stories.setdefault(story_img_ids, []).append(story)

		# story_img_ids -> reference story str(s)
		stimgids_to_Gts = {}
		for stimgids in stimgids_to_Res:
			gd_story_strs = []
			for story in stimgids_to_stories[stimgids]:
				gd_sent_ids = self.vist_sis.Stories[story['id']]['sent_ids']
				gd_story_str = ' '.join(self.vist_sis.Sents[sent_id]['text'] for sent_id in gd_sent_ids)
				# ignore some weird token
				gd_story_strs.append(gd_story_str.encode('ascii', 'ignore').decode('ascii'))
			stimgids_to_Gts[stimgids] = gd_story_strs

		# No PTB tokenization is applied: the raw strings go straight to the scorers.
		self.stimgids_to_Res = stimgids_to_Res
		self.stimgids_to_Gts = stimgids_to_Gts

		# =================================================
		# Set up scorers
		# =================================================
		print('setting up scorers...')
		# (measure key, factory, result name(s)) in run order. The factories
		# defer construction so unselected scorers are never instantiated.
		scorer_specs = [
			('bleu', lambda: Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
			('meteor', lambda: Meteor(), "METEOR"),
			('rouge', lambda: Rouge(), "ROUGE_L"),
			('cider', lambda: Cider(), "CIDEr"),
		]
		scorers = [
			(make(), method)
			for key, make, method in scorer_specs
			if not measure or key in measure
		]

		# =================================================
		# Compute scores
		# =================================================
		for scorer, method in scorers:
			print('computing %s score ...' % (scorer.method()))
			score, scores = scorer.compute_score(self.stimgids_to_Gts, self.stimgids_to_Res)
			# A concrete list works whether the setter iterates or indexes the ids.
			stimgid_list = list(self.stimgids_to_Gts.keys())
			if isinstance(method, list):
				# Multi-valued scorer (Bleu): one entry per result name.
				for sc, scs, m in zip(score, scores, method):
					self.setEval(sc, m)
					self.setStimgidsToEval(scs, stimgid_list, m)
					print('%s: %.3f' % (m, sc))
			else:
				self.setEval(score, method)
				self.setStimgidsToEval(scores, stimgid_list, method)
				print('%s: %.3f' % (method, score))

		self.setEvalStimgids()