Example #1
class AQMQuestioner(Agent):
    def __init__(self, qEncoderParam=None, qDecoderParam=None, aEncoderParam=None, aDecoderParam=None, imgFeatureSize=0,
                 verbose=1):
        super(AQMQuestioner, self).__init__()
        if qEncoderParam is not None and qDecoderParam is not None:
            self.questioner = Questioner(qEncoderParam, qDecoderParam, imgFeatureSize, verbose)
            self.startToken = qEncoderParam['startToken']
            self.endToken = qEncoderParam['endToken']
        if aEncoderParam is not None and aDecoderParam is not None:
            self.appAnswerer = Answerer(aEncoderParam, aDecoderParam, verbose)
        self.questions = [] # for Guesser
        self.quesLens = [] # for Guesser
        self.answers = [] # for Guesser
        self.ansLens = [] # for Guesser
        self.image = None
        self.caption = None
        self.captionLens = None
        self.dataset = None
        self.dataset_split = None
        self.logSoftmax = nn.LogSoftmax(dim=0)
    
    def setQuestioner(self, qBot):
        self.questioner = qBot
        self.startToken = qBot.encoder.startToken
        self.endToken = qBot.encoder.endToken
    
    def setAppAnswerer(self, aBot):
        self.appAnswerer = aBot
        self.startToken = aBot.encoder.startToken
        self.endToken = aBot.encoder.endToken
    
    def setData(self, dataset, dataset_split):
        self.dataset = dataset
        self.dataset_split = dataset_split
        # (numImages,): element c holds p(a_j | image_c, q_(1:j), a_(1:j-1))
    
    def setPrior(self, predFeature, gtFeatures, lda=1):
        '''
        Arguments:
            predFeature, gtFeatures : (Image, Feature), predicted and
                ground-truth image features; the per-image score is their
                negative L2 distance, shape (Image,).
            lda : lambda for scaling the score
        '''
        diff = predFeature - gtFeatures
        score = -torch.sum(diff * diff, 1).sqrt()
        self.prior = lda * self.logSoftmax(score * 10).data
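
    # A minimal sketch of what setPrior stores (hypothetical shapes): with
    # predFeature and gtFeatures of shape (numImages, featDim),
    #
    #   score = -((predFeature - gtFeatures) ** 2).sum(1).sqrt()  # -L2 distance
    #   prior = lda * log_softmax(score * 10, dim=0)              # (numImages,)
    #
    # so images whose predicted features lie closer to the ground truth get a
    # higher (less negative) log-prior.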

    def reset(self):
        self.questioner.reset()
        self.appAnswerer.reset()
        self.questions = [] # for Guesser
        self.quesLens = [] # for Guesser
        self.answers = [] # for Guesser
        self.ansLens = [] # for Guesser
        self.image = None
        self.caption = None
        self.captionLens = None
        self.prior = None

    def observe(self, round, ques=None, ans=None, image=None, **kwargs):
        """
        Round == -1:
        A observe image & caption
        Q observe caption
        Round > -1:
        A observe GT Q&A, forward and get loss
        Q observe GT Question, forward and get loss, then observe A
        Feature net predict and compute loss with mse

        How A work:
        First round: 
        embed image, caption, question & answer
        embed fact(first round is caption, following is QA pairs)
        embed un-embedded question(the new one)
        embed fact, new q and image to dialog RNN
        get state 0, 1: question RNN cell & hidden state, 2: DialogRNN hidden state

        encodeStates & caption->decoder
        Following:
        encodeState & last answer -> decoder
        
        Guessor need P(a|c) -> reset Q -> set caption C & image 0 -> Generate answer as well as prob
        Can sample from the return matrix, and use sum to compuate the log(P(a_j|c)), for speicifc round 
        of answer

        In our code (Eval):
        A, Q observe generated caption and image
        Guesser in Q has belief
        Q generate question
        A observe
        A generate answer
        Q observe
        Guess in Q has belief
        """

        if ques is not None:
            self.questions.append(ques)
            self.quesLens.append(kwargs['quesLens'])
        if ans is not None:
            self.answers.append(ans)
            self.ansLens.append(kwargs['ansLens'])
        if image is not None:
            self.image = image
        if 'captionLens' in kwargs:
            self.captionLens = kwargs['captionLens']
        if 'caption' in kwargs:
            self.caption = kwargs['caption']

        self.questioner.observe(round, ques, ans=ans, **kwargs)
        if self.training:
            self.appAnswerer.observe(round, ans, ques=ques, image=image, **kwargs)

    def p_a(self, candImgIdxs, candQues, candQuesLens, ansParams=None, batchSize=20, returnCandAnswers=False,
            randAnswers=None, randAnsLens=None, noHistory=False, numA=None):
        '''
        Calculate p_reg(a | c, q_t, history), normalized over candidate answers.

        Return:
            probArray : (quesSize, imageSize, ansSize), where
                probArray[i, j, k] = P(a(q(i), c(k)) | q(i), c(j)),
                i.e. the probability of the k-th candidate answer (decoded for
                question q(i)) under candidate image c(j).
        '''
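
        # Shape sketch (hypothetical sizes): with 5 candidate questions,
        # 100 candidate images, and numA=20 candidate answers,
        #
        #   probs = self.p_a(candImgIdxs, candQues, candQuesLens, numA=20)
        #
        # yields probs of shape (5, 100, 20), where probs[i, j] sums to 1
        # over the 20 candidate answers.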

        assert len(self.answers) == len(self.questions), "Didn't observe full QA pairs!"
        assert self.dataset is not None, "Please set dataset!"
        assert len(candQues) == len(candQuesLens)
        # Fall back to a single batch when batchSize doesn't evenly divide
        # the candidate images or exceeds their count.
        if len(candImgIdxs) % batchSize != 0 or batchSize > len(candImgIdxs):
            batchSize = len(candImgIdxs)

        # (quesSize, 1), to fit the observe() function
        candQuesLens = candQuesLens.unsqueeze(1)

        original_split = self.dataset.split
        self.dataset.split = self.dataset_split
        dataloader = DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=self.dataset.collate_fn)

        imgList = []
        batchSize = min(batchSize, len(candImgIdxs))
        iterTime = len(candImgIdxs) // batchSize
        candImgIdxs = torch.sort(candImgIdxs, dim=0)[0].cpu().numpy()
        imgidxIdx = 0 # To index candImgIdxs

        # assert iterTime == 1, "Deepcopy bug is not fixed!"
        myParams = []

        if numA is None:
            numA = batchSize
            aBeamSize = 1
        elif batchSize < numA:
            aBeamSize = numA // batchSize + (numA % batchSize > 0)
        else:   # batchSize >= numA
            aBeamSize = 1
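        # (The beam size above guarantees batchSize * aBeamSize >= numA, so
        #  numA candidate answers can always be gathered from a single batch.)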

        if ansParams is None or noHistory:
            for idx, batch in enumerate(dataloader):
                if candImgIdxs[imgidxIdx] == idx:
                    if self.dataset.useGPU:
                        batch = {key: v.cuda() for key, v in batch.items() \
                                                        if hasattr(v, 'cuda')}
                    else:
                        batch = {key: v.contiguous() for key, v in batch.items() \
                                                        if hasattr(v, 'cuda')}

                    imgList.append(Variable(batch['img_feat'], volatile=True)) # (batch_size, img_embed)
                    imgidxIdx += 1
                    if imgidxIdx == len(candImgIdxs):
                        break
            imgList = torch.cat(imgList, 0)
        else:
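            # Reuse cached answerer states from ansParams: pick the entries
            # whose global image index is among candImgIdxs, then merge and
            # re-split them into uniform batches.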
            maxIdx = 0
            for bs, params in ansParams:
                maxIdx += bs
                selList = []
                while imgidxIdx < len(candImgIdxs) and candImgIdxs[imgidxIdx] < maxIdx:
                    selList.append(candImgIdxs[imgidxIdx] - maxIdx + bs)
                    imgidxIdx += 1
                if len(selList) > 0:
                    myParams.append(self.appAnswerer.selectParam(params, np.array(selList)))
            if len(myParams) > 1:
                params = myParams[0]
                for i in range(1, len(myParams)):
                    params = self.appAnswerer.mergeParam(params, myParams[i])
                myParams = self.appAnswerer.splitParam(params, batchSize)

        numRounds = len(self.questions)
        probArray = torch.FloatTensor(len(candQues), len(candImgIdxs), numA).fill_(0)
        if self.dataset.useGPU:
            probArray = probArray.cuda()

        candAns = []
        candAnsLens = []
        numQues = len(candQues)

        # Conditioning: run the approximate Answerer over the observed
        # history for each candidate-image batch and cache its state.
        if ansParams is None or noHistory:
            for i in range(iterTime):
                image = imgList[i*batchSize:(i+1)*batchSize, :]
                caption = self.caption.repeat(batchSize, 1)
                captionLens = self.captionLens.repeat(batchSize)
                abotParams = None
                self.appAnswerer.eval(), self.appAnswerer.reset()
                self.appAnswerer.observe(-1, image=image, caption=caption, captionLens=captionLens)

                if not noHistory:
                    for round in range(len(self.questions)):
                        ques = self.questions[round].repeat(batchSize, 1)
                        ans = self.answers[round].repeat(batchSize, 1)
                        quesLens = self.quesLens[round].unsqueeze(0).repeat(batchSize, 1)
                        ansLens = self.ansLens[round].unsqueeze(0).repeat(batchSize, 1)
                        self.appAnswerer.observe(round, ques=ques, quesLens=quesLens)
                        self.appAnswerer.observe(round, ans=ans, ansLens=ansLens)
                        self.appAnswerer.forward()

                abotParams = self.appAnswerer.exportParams()
                myParams.append(abotParams)


        # First pass: greedily decode candidate answers a(q(i), c(j)), the
        # argmax of P(a | q(i), c(j)), for every (question, image) pair.
        ques = candQues.repeat(1, batchSize).view(batchSize*numQues, -1)
        quesLens = candQuesLens.repeat(1, batchSize).view(batchSize*numQues, -1)
        for i in range(iterTime):
            self.appAnswerer.importParams(myParams[i], rep=numQues)
            self.appAnswerer.observe(len(self.questions), ques=ques, quesLens=quesLens)

            if randAnswers is None or randAnsLens is None:
                if aBeamSize != 1:
                    appAns, appAnsLens = self.appAnswerer.forwardDecode(inference='greedy',
                                                                        beamSize=aBeamSize, topk=aBeamSize)
                    if appAns.dim() == 2:
                        appAns = appAns.unsqueeze(1)
                        appAnsLens = appAnsLens.unsqueeze(1)

                    appAns = torch.cat(appAns.unsqueeze(1).split(batchSize, 0), 1)  # [numImg, numQ, aBeamSize, L]
                    appAns = appAns.permute(1, 2, 0, 3)  # [numQ, aBeamSize, numImg, L]
                    appAns = appAns.contiguous().view(numQues, aBeamSize*batchSize, -1)[:, :numA, :]    # [numQ, numA, L]
                    appAns = appAns.transpose(0, 1) # [numA, numQ, L]

                    appAnsLens = torch.cat(appAnsLens.unsqueeze(1).split(batchSize, 0), 1)  # [numImg, numQ, aBeamSize]
                    appAnsLens = appAnsLens.permute(1, 2, 0)    # [numQ, aBeamSize*numImg]
                    appAnsLens = appAnsLens.contiguous().view(numQues, aBeamSize*batchSize)[:, :numA]
                    appAnsLens = appAnsLens.transpose(0, 1) # [numA, numQ]

                else:   # aBeamSize == 1
                    appAns, appAnsLens = self.appAnswerer.forwardDecode(inference='greedy')
                    appAns = torch.cat(appAns.unsqueeze(1).split(batchSize, 0), 1)  # [numImg, numQ, L]
                    appAnsLens = torch.cat(appAnsLens.unsqueeze(1).split(batchSize, 0), 1)  # [numImg, numQ]
                    appAns = appAns[:numA, :, :]    # [numA, numQ, L]
                    appAnsLens = appAnsLens[:numA, :]    # [numA, numQ]

            else:   # randA
                appAns, appAnsLens = randAnswers, randAnsLens
                appAns = appAns.unsqueeze(1).repeat(1, numQues, 1)  # [numA, numQ, L]
                appAnsLens = appAnsLens.unsqueeze(1).repeat(1, numQues) # [numA, numQ]

            candAns.append(appAns)
            candAnsLens.append(appAnsLens)
            myParams[i] = self.appAnswerer.exportParams()
        
        candAns = torch.cat(candAns, dim=0) # (numA, quesSize, maxLen); candAns[i, j, :] is the i-th candidate answer for question q(j)
        candAnsLens = torch.cat(candAnsLens, dim=0) # (numA, quesSize)


        # Second pass: score every candidate answer under every candidate
        # image, P(a(q(i), c(k)) | q(i), c(j)).
        for i in range(iterTime):
            for aIdx in range(numA):    # p(a | Q, C)
                self.appAnswerer.importParams(myParams[i])
                ans = candAns[aIdx, :, :].repeat(1, batchSize).view(batchSize*numQues, -1)
                ansLens = candAnsLens[aIdx, :].unsqueeze(1).repeat(1, batchSize).view(batchSize*numQues, -1)
                self.appAnswerer.observe(len(self.questions), ans=ans, ansLens=ansLens)
                ansLogProbs = self.appAnswerer.forward()
                
                # Drop <START> from ans: shift targets one step left, zero-padding the end
                padColumn = ans.data.new(batchSize*numQues, 1).fill_(0)
                padColumn = Variable(padColumn)
                target = torch.cat([ans, padColumn], dim=1)[:, 1:]

                ansLogProbs = torch.gather(ansLogProbs, 2, target.unsqueeze(2)).squeeze(2)
                mask = Variable(torch.ByteTensor(ansLogProbs.size()).fill_(0))
                # Mask positions at or beyond each answer's length so padding
                # doesn't contribute to the summed log-prob (built column by
                # column; a vectorized version would be more elegant).
                cmpIdx = Variable(torch.LongTensor(1))
                if self.dataset.useGPU:
                    cmpIdx = cmpIdx.cuda()
                    mask = mask.cuda()
                for maskIdx in range(ansLogProbs.size()[1]):
                    cmpIdx.data.fill_(maskIdx)
                    mask[:, maskIdx] = torch.ge(cmpIdx, ansLens).byte()
                ansLogProbs.masked_fill_(mask, 0)
                curP = torch.sum(ansLogProbs, dim=1, keepdim=True).data
                curP = torch.cat(curP.split(batchSize, 0), 1).t()
                probArray[:, i*batchSize:(i+1)*batchSize, aIdx] = curP


        # Normalize: exponentiate the summed log-probs and normalize over
        # candidate answers (dim 2).
        probArray = torch.exp(probArray)
        probSum = torch.sum(probArray, dim=2, keepdim=True)
        probArray /= probSum
        self.dataset.split = original_split
        if returnCandAnswers:
            return probArray, candAns, candAnsLens
        return probArray    # [numQues, numImg, numA]

    def guess_sigRnd(self, ansParams=None, prior=None, round=None, batchSize=2048,
                     noHistory=False):
        '''
        Guess the image for a single round using the latest question & answer.

        Arguments:
            ansParams:
                A list of (batchSize, answerer params) pairs returned from the
                previous round's guess. If None, fresh answerer states are
                built from the dataset.
            prior:
                (Image,), probability scores from the previous round.
                If None, self.prior is used.
            batchSize:
                Number of images processed per batch when computing the
                posterior.

        Return:
            ansParams:
                A list of (batchSize, answerer params) pairs storing the
                information needed for the next round.
            posterior:
                (Image,), the accumulated posterior score for each image.
        '''
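
        # A minimal round-by-round usage sketch (hypothetical driver code):
        #
        #   ansParams, posterior = None, None
        #   for rnd in range(numRounds):
        #       ...  # observe this round's question & answer first
        #       ansParams, posterior = aqm.guess_sigRnd(
        #           ansParams=ansParams, prior=posterior, round=rnd)
        #   bestImgIdx = posterior.max(0)[1]   # index of the most likely image
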
        if prior is None: 
            posterior = self.prior.clone()
            if self.dataset.useGPU:
                posterior = posterior.cuda()
        else:
            posterior = prior
        assert len(self.answers) == len(self.questions), "Didn't observe full QA pairs!"
        assert self.dataset is not None, "Please set dataset!"

        original_split = self.dataset.split
        self.dataset.split = self.dataset_split
        dataloader = DataLoader(
            self.dataset,
            batch_size=batchSize,
            shuffle=False,
            num_workers=0,
            collate_fn=self.dataset.collate_fn)

        accuIdx = 0
        if ansParams is None or noHistory:
            ansParams = []
            for idx, batch in enumerate(dataloader):
                if self.dataset.useGPU:
                    batch = {key: v.cuda() for key, v in batch.items() \
                                                    if hasattr(v, 'cuda')}
                else:
                    batch = {key: v.contiguous() for key, v in batch.items() \
                                                    if hasattr(v, 'cuda')}

                image = Variable(batch['img_feat'], volatile=True) # (batch_size, img_embed)

                # The last batch of the epoch may be smaller; use its actual size
                batchSize = image.shape[0]

                caption = self.caption.repeat(batchSize, 1)
                captionLens = self.captionLens.repeat(batchSize)
                
                self.appAnswerer.eval(), self.appAnswerer.reset()
                self.appAnswerer.observe(-1, image=image, caption=caption, captionLens=captionLens)

                ques = self.questions[-1].repeat(batchSize, 1)
                ans = self.answers[-1].repeat(batchSize, 1)
                quesLens = self.quesLens[-1].unsqueeze(0).repeat(batchSize, 1)
                ansLens = self.ansLens[-1].unsqueeze(0).repeat(batchSize, 1)
                self.appAnswerer.observe(round, ques=ques, quesLens=quesLens)
                self.appAnswerer.observe(round, ans=ans, ansLens=ansLens)
                ansLogProbs = self.appAnswerer.forward(debug=False)
                ansParams.append([batchSize, self.appAnswerer.exportParams(deepcopy=False)])

                # Drop <START> from ans: shift targets one step left, zero-padding the end
                padColumn = ans.data.new(batchSize, 1).fill_(0)
                padColumn = Variable(padColumn)
                target = torch.cat([ans, padColumn], dim=1)[:, 1:]

                ansLogProbs = torch.gather(ansLogProbs, 2, target.unsqueeze(2)).squeeze(2)
                curP = torch.sum(ansLogProbs[:, :self.ansLens[round].data[0]], dim=1).data
                posterior[accuIdx:accuIdx+batchSize] += curP
                accuIdx += batchSize
        else:
            for idx, bp in enumerate(ansParams):
                batchSize, params = bp
                self.appAnswerer.importParams(params)

                ques = self.questions[-1].repeat(batchSize, 1)
                ans = self.answers[-1].repeat(batchSize, 1)
                quesLens = self.quesLens[-1].unsqueeze(0).repeat(batchSize, 1)
                ansLens = self.ansLens[-1].unsqueeze(0).repeat(batchSize, 1)
                self.appAnswerer.observe(round, ques=ques, quesLens=quesLens)
                self.appAnswerer.observe(round, ans=ans, ansLens=ansLens)
                ansLogProbs = self.appAnswerer.forward(debug=False)
                ansParams[idx][1] = self.appAnswerer.exportParams(deepcopy=False)

                # Drop <START> from ans: shift targets one step left, zero-padding the end
                padColumn = ans.data.new(batchSize, 1).fill_(0)
                padColumn = Variable(padColumn)
                target = torch.cat([ans, padColumn], dim=1)[:, 1:]

                ansLogProbs = torch.gather(ansLogProbs, 2, target.unsqueeze(2)).squeeze(2)
                curP = torch.sum(ansLogProbs[:, :self.ansLens[round].data[0]], dim=1).data
                posterior[accuIdx:accuIdx+batchSize] += curP
                accuIdx += batchSize

        self.dataset.split = original_split
        return ansParams, posterior

    def guess(self, batchSize=2048, noHistory=False):
        '''
        Assumes the guesser is called after all observations are done.

        Arguments:
            batchSize: number of images processed per batch.
            noHistory: if True, condition each round only on the caption and
                that round's QA pair, ignoring earlier rounds.

        Return:
            (Round + 1, Images): for each round, the accumulated log-score of
            p(c|a); row 0 holds the prior.
        '''
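
        # Interpretation sketch: the returned tensor accumulates log-scores;
        # row 0 is the prior and row r+1 adds log p(a_r | image, history).
        # Ranking images after the final round (assumed usage):
        #
        #   p_a_history = aqm.guess()
        #   finalScores = p_a_history[numRounds]   # (numImages,)
        #   bestImgIdx = finalScores.max(0)[1]
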
        assert len(self.answers) == len(self.questions), "Didn't observe full QA pairs!"
        assert self.dataset is not None, "Please set dataset!"

        original_split = self.dataset.split
        self.dataset.split = self.dataset_split
        dataloader = DataLoader(
            self.dataset,
            batch_size=batchSize,
            shuffle=False,
            num_workers=0,
            collate_fn=self.dataset.collate_fn)

        numRounds = len(self.questions)

        p_a_history = torch.zeros(
            (numRounds + 1, self.dataset.numDataPoints[self.dataset_split]))
        if self.dataset.useGPU:
            p_a_history = p_a_history.cuda()
        if self.prior is not None:
            p_a_history[0, :] = self.prior
        
        accuIdx = 0
        for idx, batch in enumerate(dataloader):
            if self.dataset.useGPU:
                batch = {key: v.cuda() for key, v in batch.items() \
                                                if hasattr(v, 'cuda')}
            else:
                batch = {key: v.contiguous() for key, v in batch.items() \
                                                if hasattr(v, 'cuda')}

            image = Variable(batch['img_feat'], volatile=True) # (batch_size, img_embed)

            # The last batch of the epoch may be smaller; use its actual size
            batchSize = image.shape[0] 

            caption = self.caption.repeat(batchSize, 1)
            captionLens = self.captionLens.repeat(batchSize)

            if not noHistory:
                self.appAnswerer.eval(), self.appAnswerer.reset()
                self.appAnswerer.observe(-1, image=image, caption=caption, captionLens=captionLens)
                for round in range(numRounds):
                    ques = self.questions[round].repeat(batchSize, 1)
                    ans = self.answers[round].repeat(batchSize, 1)
                    quesLens = self.quesLens[round].unsqueeze(0).repeat(batchSize, 1)
                    ansLens = self.ansLens[round].unsqueeze(0).repeat(batchSize, 1)
                    self.appAnswerer.observe(round, ques=ques, quesLens=quesLens)
                    self.appAnswerer.observe(round, ans=ans, ansLens=ansLens)
                    ansLogProbs = self.appAnswerer.forward(debug=False)

                    # Drop <START> from ans: shift targets one step left, zero-padding the end
                    padColumn = ans.data.new(batchSize, 1).fill_(0)
                    padColumn = Variable(padColumn)
                    target = torch.cat([ans, padColumn], dim=1)[:, 1:]

                    ansLogProbs = torch.gather(ansLogProbs, 2, target.unsqueeze(2)).squeeze(2)
                    curP = torch.sum(ansLogProbs[:, :self.ansLens[round].data[0]], dim=1).data
                    p_a_history[round+1, accuIdx:accuIdx+batchSize] = p_a_history[round, accuIdx:accuIdx+batchSize] + curP
            else:
                for round in range(numRounds):
                    self.appAnswerer.eval(), self.appAnswerer.reset()
                    self.appAnswerer.observe(-1, image=image, caption=caption, captionLens=captionLens)
                    ques = self.questions[round].repeat(batchSize, 1)
                    ans = self.answers[round].repeat(batchSize, 1)
                    quesLens = self.quesLens[round].unsqueeze(0).repeat(batchSize, 1)
                    ansLens = self.ansLens[round].unsqueeze(0).repeat(batchSize, 1)
                    self.appAnswerer.observe(round, ques=ques, quesLens=quesLens)
                    self.appAnswerer.observe(round, ans=ans, ansLens=ansLens)
                    ansLogProbs = self.appAnswerer.forward(debug=False)

                    # Drop <START> from ans: shift targets one step left, zero-padding the end
                    padColumn = ans.data.new(batchSize, 1).fill_(0)
                    padColumn = Variable(padColumn)
                    target = torch.cat([ans, padColumn], dim=1)[:, 1:]

                    ansLogProbs = torch.gather(ansLogProbs, 2, target.unsqueeze(2)).squeeze(2)
                    curP = torch.sum(ansLogProbs[:, :self.ansLens[round].data[0]], dim=1).data
                    p_a_history[round + 1, accuIdx:accuIdx + batchSize] = p_a_history[round,
                                                                          accuIdx:accuIdx + batchSize] + curP

            accuIdx += batchSize
            
        self.dataset.split = original_split
        return p_a_history
        
    def forward(self):
        return self.questioner.forward() 
    
    def aForward(self):
        return self.appAnswerer.forward()
    
    def predictImage(self):
        return self.questioner.predictImage()
    
    def forwardDecode(self, inference='sample', beamSize=1, maxSeqLen=20, topk=1, retLogProbs=False, gamma=0, delta=0):
        return self.questioner.forwardDecode(inference=inference, beamSize=beamSize, maxSeqLen=maxSeqLen, topk=topk, retLogProbs=retLogProbs, gamma=gamma, delta=delta)
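
# A minimal construction sketch (hypothetical bots, dataset, and features):
#
#   aqm = AQMQuestioner()
#   aqm.setQuestioner(qBot)            # trained Questioner
#   aqm.setAppAnswerer(approxABot)     # approximate Answerer used for guessing
#   aqm.setData(dataset, 'val')
#   aqm.setPrior(qBot.predictImage(), gtImgFeatures)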