示例#1
0
    def _forward(self, ofc_feats, oatt_feats,densecap, seq, att_masks=None,personality=None):
        batch_size = self.batch_size
        seq_per_img = seq.shape[0] // batch_size
        outputs = torch.zeros(batch_size*seq_per_img, seq.size(1) - 1, self.vocab_size+1,dtype=torch.float).cuda()
        # Prepare the features
        rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = self._prepare_feature(ofc_feats, oatt_feats,att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost
        encodestate = self.enc_init_hidden(batch_size*5)
        encoder_cells =[]
        for k in range(densecap.size(-1)):
            w =  densecap[:,:,k].clone()
            embedw = self.embed(w)
            embedw = embedw.contiguous().view(-1,embedw.size(-1)).contiguous()
            encodestate= self.encoder(embedw, (encodestate[0],encodestate[1])) 
            encoder_cells.append(encodestate[1].contiguous().view(batch_size,5,encodestate[1].size(-1)))       
        hstate, cstate = encodestate
        att_feats = torch.stack(encoder_cells).cuda()
        att_feats = att_feats.contiguous().permute(1,2,0,3)

        fc_feats =  hstate.contiguous().view(batch_size,5,encodestate[0].size(-1))
        fc_feats =  fc_feats.contiguous().view(batch_size,-1) 
        p_att_feats = self.ctx2att_t(att_feats)
        decodestate = self.init_hidden(batch_size*seq_per_img)
        if seq_per_img > 1:
            fc_feats, att_feats, p_att_feats, att_masks = utils.repeat_tensors(seq_per_img,
                    [fc_feats, att_feats, p_att_feats, att_masks])
            rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = utils.repeat_tensors(seq_per_img,[rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks])

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
                sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                    #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                    # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                    prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
            else:
                it = seq[:, i].clone()          
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break
             
            output,  decodestate = self.get_logprobs_state(it,personality, fc_feats, att_feats, p_att_feats, att_masks,rp_fc_feats,               rp_att_feats, rpp_att_feats, rp_att_masks, decodestate)
            outputs[:, i] = output
        return outputs
    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
        """Embed attention features and build the masks used by the decoder.

        Args:
            att_feats: attention features. They are clipped by
                ``self.clip_att`` and embedded with ``self.att_embed`` through
                ``pack_wrapper``.
            att_masks: optional mask. When None, an all-ones LongTensor over
                the first two dims of the embedded features is used.
            seq: optional target tokens. When given, the last column is
                dropped and a padding-and-causal mask is built for the rest.

        Returns:
            ``(att_feats, seq, att_masks, seq_mask)``.
            - ``att_masks`` gains a singleton dim at -2.
            - ``seq_mask`` is None when ``seq`` is None.
            - Features and masks are repeated per caption when ``seq`` holds
              several captions per image.
        """
        att_feats, att_masks = self.clip_att(att_feats, att_masks)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
        att_masks = att_masks.unsqueeze(-2)

        if seq is None:
            return att_feats, seq, att_masks, None

        # The final token is never fed as an input, so crop it.
        seq = seq[:, :-1]
        pad_mask = seq.data > 0
        pad_mask[:, 0] = 1  # bos is always visible
        causal = subsequent_mask(seq.size(-1)).to(pad_mask)
        seq_mask = pad_mask.unsqueeze(-2) & causal

        captions_per_image = seq.shape[0] // att_feats.shape[0]
        if captions_per_image > 1:
            att_feats, att_masks = utils.repeat_tensors(captions_per_image, [att_feats, att_masks])

        return att_feats, seq, att_masks, seq_mask
    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size * seq_per_img)

        outputs = fc_feats.new_zeros(batch_size * seq_per_img,
                                     seq.size(1) - 1, self.vocab_size + 1)

        # Prepare the features
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost

        if seq_per_img > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
                seq_per_img,
                [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

        for i in range(seq.size(1) - 1):
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break

            output, state = self.get_logprobs_state(seq[:, i], p_fc_feats,
                                                    p_att_feats, pp_att_feats,
                                                    p_att_masks, state)
            outputs[:, i] = output

        return outputs
示例#4
0
    def _sample_beam(self,
                     semmantic_feat,
                     semantic1_feat,
                     att_feats,
                     att1_feat,
                     box_feat,
                     box1_feat,
                     opt={}):
        """Beam-search decoding over the features fused by ``self.att_feat``.

        Options read from ``opt``: ``beam_size`` (10), ``group_size`` (1) and
        ``sample_n`` (10).
        - When ``sample_n == beam_size``, every finished beam of an image is
          written to its own output row.
        - Otherwise only the first beam, which has the highest cumulative
          score, is written to row ``k``.

        Returns ``(seq, seqLogprobs)`` of shapes
        ``(batch * sample_n, seq_length)`` and
        ``(batch * sample_n, seq_length, vocab_size + 1)``. Positions past a
        beam's length are left at zero.
        """
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
        batch_size = att_feats.size(0)

        fused, fused1 = self.att_feat(semmantic_feat, semantic1_feat, att_feats,
                                      att1_feat, box_feat, box1_feat)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        n_rows = batch_size * sample_n
        seq = att_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
        seqLogprobs = att_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

        self.done_beams = [[] for _ in range(batch_size)]

        # One step with <bos> (token 0) yields the initial distribution for the search.
        state = self.init_hidden(batch_size)
        bos = att_feats.new_zeros([batch_size], dtype=torch.long)
        logprobs, state = self.get_logprobs_state(bos, fused, fused1, state)

        fused, fused1 = utils.repeat_tensors(beam_size, [fused, fused1])
        self.done_beams = self.beam_search(state, logprobs, fused, fused1, opt=opt)

        # Keep all beams (one row each) or just the best one per image.
        n_keep = sample_n if sample_n == beam_size else 1
        for k in range(batch_size):
            for j in range(n_keep):
                beam = self.done_beams[k][j]
                length = beam['seq'].shape[0]
                row = k * n_keep + j
                seq[row, :length] = beam['seq']
                seqLogprobs[row, :length] = beam['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs
示例#5
0
    def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Run ``old_beam_search`` separately for each image.

        Each image's prepared features are repeated ``beam_size`` times and
        fed one <bos> step before the search runs. Rows are filled as follows:
        - When ``sample_n == beam_size``, every finished beam fills its own
          row.
        - Otherwise row ``k`` receives the first beam, which has the highest
          cumulative score.

        Rows are written over their full width.

        Returns ``(seq, seqLogprobs)``.
        """
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        n_rows = batch_size * sample_n
        seq = fc_feats.new_zeros((n_rows, self.seq_length), dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

        n_keep = sample_n if sample_n == beam_size else 1
        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            one_masks = p_att_masks[k:k + 1] if att_masks is not None else None
            tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(
                beam_size,
                [p_fc_feats[k:k + 1], p_att_feats[k:k + 1], pp_att_feats[k:k + 1], one_masks])

            # Single <bos> step to seed the search.
            bos = fc_feats.new_zeros([beam_size], dtype=torch.long)
            logprobs, state = self.get_logprobs_state(
                bos, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)

            self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats,
                                                      tmp_att_feats, tmp_p_att_feats,
                                                      tmp_att_masks, opt=opt)
            for j in range(n_keep):
                seq[k * n_keep + j, :] = self.done_beams[k][j]['seq']
                seqLogprobs[k * n_keep + j, :] = self.done_beams[k][j]['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs
示例#6
0
    def get_logprobs_state(self, it,personality, fc_feats, att_feats, p_att_feats, att_masks,rp_fc_feats,rp_att_feats, rpp_att_feats, rp_att_masks, state):
        """Run one decoding step and return ``(log_probs, new_state)``.

        Args:
            it: word indices for this step, embedded with ``self.embed``.
            personality: optional tensor. The column indices of its non-zero
                entries are encoded with ``self.personality_encoder``, repeated
                so they line up with the rows of ``it``, and concatenated to the
                word embedding. Presumably this is one-hot (one non-zero per
                row) — confirm with callers. When None, no personality is
                appended.
            fc_feats ... rp_att_masks, state: forwarded unchanged to ``self.core``.

        Returns:
            ``F.log_softmax`` over dim 1 of ``self.logit(output)``, and the
            state returned by ``self.core``.
        """
        # 'it' contains a word index
        xt = self.embed(it)
        if personality is not None:
            # personality.size(0) is only meaningful (and only safe) when a
            # personality is supplied. Reading it unconditionally crashed the
            # personality=None path before the guard was reached.
            batch_size = personality.size(0)
            seq_per_img = xt.size(0) // batch_size
            pers_encoded = self.personality_encoder(personality.nonzero(as_tuple=True)[1])
            pers_encoded = utils.repeat_tensors(seq_per_img, pers_encoded)
            xt = torch.cat((xt, pers_encoded), 1)
        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks,
                                  rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks)
        logitoutput = self.logit(output)
        finallogprobs = F.log_softmax(logitoutput, dim=1)

        return finallogprobs, state
示例#7
0
    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        if seq.ndim == 3:  # B * seq_per_img * seq_len
            seq = seq.reshape(-1, seq.shape[2])
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size * seq_per_img)

        outputs = fc_feats.new_zeros(batch_size * seq_per_img,
                                     seq.size(1) - 1, self.vocab_size + 1)

        # Prepare the features
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost

        if seq_per_img > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
                seq_per_img,
                [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwiste no need to sample
                sample_prob = fc_feats.new(batch_size * seq_per_img).uniform_(
                    0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(outputs[:, i - 1].detach(
                    ))  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break

            output, state = self.get_logprobs_state(it, p_fc_feats,
                                                    p_att_feats, pp_att_feats,
                                                    p_att_masks, state)
            outputs[:, i] = output

        return outputs
    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        sample_n = int(opt.get('sample_n', 1))
        if beam_size > 1:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
                sample_n, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

        seq = fc_feats.new_zeros((batch_size * sample_n, self.seq_length),
                                 dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n,
                                         self.seq_length, self.vocab_size + 1)
        for t in range(self.seq_length):
            if t == 0:  # input <bos>
                it = fc_feats.new_zeros(batch_size * sample_n,
                                        dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats,
                                                      p_att_feats,
                                                      pp_att_feats,
                                                      p_att_masks, state)

            # sample the next word
            it, _ = self.sample_next_word(logprobs, sample_method)

            # stop when all finished
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)

            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq.detach(), seqLogprobs
示例#9
0
    def _forward(self, semantic_feat, semantic1_feat, att_feats, att1_feats,
                 box_feat, box1_feat, seq):
        batch_size = att_feats.size(0)
        if seq.ndim == 3:
            seq = seq.reshape(-1, seq.shape[2])
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size * seq_per_img)
        outputs = att_feats.new_zeros(batch_size * seq_per_img,
                                      seq.size(1) - 1, self.vocab_size + 1)

        # att_feat
        new_semantic_feat, new_semantic1_feat = self.att_feat(
            semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat,
            box1_feat)

        if seq_per_img > 1:
            new_semantic_feat, new_semantic1_feat = utils.repeat_tensors(
                seq_per_img, [new_semantic_feat, new_semantic1_feat])

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwiste no need to sample
                sample_prob = att_feats.new(batch_size * seq_per_img).uniform_(
                    0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(outputs[:, i - 1].detach(
                    ))  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break

            output, state = self.get_logprobs_state(it, new_semantic_feat,
                                                    new_semantic1_feat, state)
            outputs[:, i] = output

        return outputs
    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size * seq_per_img)
        outputs = []

        if seq_per_img > 1:
            fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0:  # otherwiste no need to sample
                    sample_prob = fc_feats.data.new(
                        batch_size * seq_per_img).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i - 1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i - 1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(
                            outputs[-1].data
                        )  # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(
                            0, sample_ind,
                            torch.multinomial(prob_prev,
                                              1).view(-1).index_select(
                                                  0, sample_ind))
                else:
                    it = seq[:, i - 1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i - 1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt.unsqueeze(0), state)
            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))),
                                   dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
示例#11
0
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            """Penalise tokens already picked by earlier beam groups at the same local step.

            For each of the first ``divm`` groups in ``beam_seq_table``, the
            tokens its ``bdash`` beams chose at ``local_time = t - divm`` are
            counted per row. ``diversity_lambda`` times those counts is then
            subtracted from ``logprobs``.
            - At ``local_time == 0`` the counts are applied as is.
            - Otherwise they are first repeated ``bdash`` times.

            Returns ``(augmented_logprobs, unaugmented_copy)``.
            """
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            n_rows = beam_seq_table[0].shape[0]

            if divm <= 0:
                return logprobs, unaug_logprobs

            counts = logprobs.new_zeros(n_rows, logprobs.shape[-1])
            for prev_choice in range(divm):
                picks = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                # One scatter_add over all bdash labels; duplicate indices accumulate.
                counts.scatter_add_(1, picks[:, :bdash], counts.new_ones(n_rows, bdash))

            penalty = counts if local_time == 0 else utils.repeat_tensors(bdash, counts)
            return logprobs - penalty * diversity_lambda, unaug_logprobs
    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Beam search that keeps only the top beam for each image.

        A single <bos> step (token 0) seeds ``self.beam_search`` with the
        initial distribution. The prepared features are repeated
        ``opt['beam_size']`` times (default 10) for the search.

        Returns ``(seq, seqLogprobs)`` of shapes ``(batch, seq_length)`` and
        ``(batch, seq_length, vocab_size + 1)``. Positions past the beam's
        length stay zero.
        """
        beam_size = opt.get('beam_size', 10)
        n_images = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)

        # let's assume this for now
        assert beam_size <= self.vocab_size + 1
        seq = fc_feats.new_zeros((n_images, self.seq_length), dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(n_images, self.seq_length, self.vocab_size + 1)

        state = self.init_hidden(n_images)
        bos = fc_feats.new_zeros([n_images], dtype=torch.long)
        # logprobs: batch_size x (vocab_size + 1)
        logprobs, state = self.get_logprobs_state(bos, p_fc_feats, p_att_feats,
                                                  pp_att_feats, p_att_masks, state)

        expanded = utils.repeat_tensors(
            beam_size, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])
        done_beams = self.beam_search(state, logprobs, *expanded, opt=opt)

        for k in range(n_images):
            best = done_beams[k][0]  # the first beam has the highest cumulative score
            length = best['seq'].shape[0]
            seq[k, :length] = best['seq']
            seqLogprobs[k, :length] = best['logps']
        return seq, seqLogprobs
    def _sample(self, fc_feats, att_feats, topic_vec, att_masks=None, opt={}):

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        if beam_size > 1:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size*sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
                [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
            )

        trigrams = [] # will be a list of batch_size dictionaries
        
        seq = fc_feats.new_zeros((batch_size*sample_n, self.seq_length), dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
        decoder_states = []

        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.new_zeros(batch_size*sample_n, dtype=torch.long)

            logprobs, state, output = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, topic_vec, output_logsoftmax=output_logsoftmax)
            
            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            if remove_bad_endings and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
                # Make it impossible to generate bad_endings
                tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
                logprobs = logprobs + tmp

            # Mess with trigrams
            # Copy from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:,t-3:t-1]
                for i in range(batch_size): # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current  = seq[i][t-1]
                    if t == 3: # initialize
                        trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]: # add to list
                            trigrams[i][prev_two].append(current)
                        else: # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:,t-2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i,j] += 1
                # Apply mask to log probs
                #logprobs = logprobs - (mask * 1e9)
                alpha = 2.0 # = 4
                logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)

            # sample the next word
            if t == self.seq_length: # skip if we achieve maximum length
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)
            seq[:,t] = it
            seqLogprobs[:,t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

            decoder_states.append(output)

        return seq, seqLogprobs, decoder_states
示例#14
0
    def _samplen(self, ofc_feats, oatt_feats,densecap, att_masks=None,personality=None, opt={}):

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        opt['block_trigrams'] =0
        opt['remove_bad_endings'] =1
        opt['decoding_constraint'] =1
        decoding_constraint = opt.get('decoding_constraint', 1)
        block_trigrams = opt.get('block_trigrams', 1)
        remove_bad_endings = opt.get('remove_bad_endings', 1)
        sample_n = int(opt.get('sample_n', 3))
        no_unk=1
        if beam_size > 1:
            return self._sample_beam(ofc_feats, oatt_feats,densecap, att_masks,personality, opt)

        batch_size = densecap.size(0)
        # Prepare the features
        rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = self._prepare_feature(ofc_feats, oatt_feats,att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost
        if sample_n > 1:
            personality, densecap, rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks = utils.repeat_tensors(sample_n,
                [personality,densecap,rp_fc_feats, rp_att_feats, rpp_att_feats, rp_att_masks]
            ) 
        encodestate = self.enc_init_hidden(batch_size*5*sample_n)
        encoder_cells =[]
        for k in range(densecap.size(-1)):
            w =  densecap[:,:,k].clone()
            embedw = self.embed(w)
            embedw = embedw.contiguous().view(-1,embedw.size(-1)).contiguous()
            encodestate= self.encoder(embedw, (encodestate[0],encodestate[1]))
            encoder_cells.append(encodestate[1].contiguous().view(batch_size*sample_n,5,encodestate[1].size(-1)))
        hstate, cstate = encodestate
        att_feats = torch.stack(encoder_cells).cuda()
        p_att_feats = att_feats.contiguous().permute(1,2,0,3)

        fc_feats =  hstate.contiguous().view(batch_size*sample_n,5,encodestate[0].size(-1))
        p_fc_feats =  fc_feats.contiguous().view(batch_size*sample_n,-1)
        pp_att_feats = self.ctx2att_t(p_att_feats)
        p_att_masks =  att_masks

        decodestate = self.init_hidden(batch_size*sample_n)
      
        #p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        trigrams = [] # will be a list of batch_size dictionaries
        alogprobs1 = torch.zeros(batch_size*sample_n,self.seq_length+1, self.vocab_size+1).cuda()
        alogprobs= torch.zeros(batch_size*sample_n,self.seq_length+1, self.vocab_size+1).cuda()
        for bk in range(alogprobs.size(0)):
            alogprobs[bk]=nn.LogSoftmax(dim=1)(alogprobs1[bk])
        
        seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = torch.zeros(batch_size*sample_n, self.seq_length,dtype=torch.float).cuda()
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.new_zeros(batch_size*sample_n, dtype=torch.long)
            logprobs, decodestate = self.get_logprobs_state(it,personality, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks,rp_fc_feats,rp_att_feats, rpp_att_feats, rp_att_masks, decodestate)
            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-10e20'))
                logprobs = logprobs + tmp

            if remove_bad_endings and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
                # Impossible to generate remove_bad_endings
                tmp[torch.from_numpy(prev_bad.astype(np.bool_)), 0] = float('-10e20')
                logprobs = logprobs + tmp

            # Mess with trigrams
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:,t-3:t-1]
                for i in range(batch_size): # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current  = seq[i][t-1]
                    if t == 3: # initialize
                        trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]: # add to list
                            trigrams[i][prev_two].append(current)
                        else: # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:,t-2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i,j] += 1
                # Apply mask to log probs
                #logprobs = logprobs - (mask * 1e9)
                alpha = 10e20 # = 4
                logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
            if no_unk==1:
                mask2 = torch.zeros(logprobs.size(), requires_grad=False).cuda()
                mask2[:,mask2.size(1)-1] =-10e20
                logprobs= logprobs+ mask2
            logprobs = F.log_softmax(logprobs,dim=-1)
            # sample the next word
            if t == self.seq_length: # skip if we achieve maximum length
                break
            
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)
            seq[:,t] = it
            seqLogprobs[:,t] = sampleLogprobs.view(-1)
            # quit loop if all sequences have finished
            alogprobs[:, t] = logprobs
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs, alogprobs