Example #1
    def _sample_beam(self, fc_feats, attri_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)
        memory = self.model.encode(att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = None
            tmp_memory = memory[k:k + 1].expand(*((beam_size,) + memory.size()[1:])).contiguous()
            tmp_att_masks = att_masks[k:k + 1].expand(*((beam_size,) + att_masks.size()[1:])).contiguous() if att_masks is not None else None

            for t in range(1):
                if t == 0:  # input <bos>
                    it = Variable(fc_feats.data.new(beam_size).long().zero_()) if utils.under_0_4() else fc_feats.new_zeros([beam_size], dtype=torch.long)

                logprobs, state = self.get_logprobs_state(it, tmp_memory, tmp_att_masks, state)

            self.done_beams[k] = self.beam_search(state, logprobs, tmp_memory, tmp_att_masks, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        if utils.under_0_4():
            return Variable(seq.transpose(0, 1)), Variable(seqLogprobs.transpose(0, 1))
        else:
            return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
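The examples in this collection branch on utils.under_0_4() to support both pre-0.4 Variable-based PyTorch and the modern tensor API. The helper itself never appears in these snippets; a minimal sketch of what it presumably does (a check on torch.__version__; the real implementation may differ) is:

# Hypothetical sketch of the utils.under_0_4() helper assumed by the snippets.
import torch

def under_0_4():
    # True for PyTorch releases older than 0.4, where Tensor and Variable
    # were distinct types and .item() / new_zeros() did not exist yet.
    major, minor = (int(p) for p in torch.__version__.split('.')[:2])
    return (major, minor) < (0, 4)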
Example #2
    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'lstm':
            return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
                    weight.new_zeros(self.num_layers, bsz, self.rnn_size))
        else:
            return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
Example #3
    def sample(self, fc_feats, att_feats, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(fc_feats)

        seq = []
        seqLogprobs = []
        for t in range(self.seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.data.new(batch_size).long().zero_()
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(
                    1,
                    Variable(it, requires_grad=False) if utils.under_0_4() else
                    it)  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing

            xt = self.embed(
                Variable(it, requires_grad=False) if utils.under_0_4() else it)

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)  #seq[t] the input of t+2 time step
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt, fc_feats, att_feats, state)
            logprobs = F.log_softmax(self.logit(
                self.dropout(output))) if utils.under_0_4() else F.log_softmax(
                    self.logit(self.dropout(output)), dim=1)

        return torch.cat([_.unsqueeze(1) for _ in seq],
                         1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                       1)
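The temperature branch above divides log-probabilities by temperature before exponentiating, which sharpens (T < 1) or flattens (T > 1) the sampling distribution; torch.multinomial renormalizes the unnormalized weights internally. A small self-contained illustration with toy logits (not from any model):

# Toy demonstration of temperature scaling as used in the sampler above.
import torch
import torch.nn.functional as F

logprobs = F.log_softmax(torch.tensor([[2.0, 1.0, 0.1]]), dim=1)
for temperature in (0.5, 1.0, 2.0):
    prob_prev = torch.exp(logprobs / temperature)  # unnormalized weights
    it = torch.multinomial(prob_prev, 1)           # multinomial renormalizes
    print(temperature, prob_prev / prob_prev.sum(), it)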
Example #4
    def _sample(self, fc_feats, attri_feats, att_feats, att_masks=None, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        decoding_constraint = opt.get('decoding_constraint', 0)
        if beam_size > 1:
            return self._sample_beam(fc_feats, attri_feats, att_feats, att_masks, opt)

        batch_size = att_feats.shape[0]

        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)

        state = None
        memory = self.model.encode(att_feats, att_masks)

        seq = Variable(att_feats.data.new(batch_size, self.seq_length).long().zero_()) if utils.under_0_4() else att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
        seqLogprobs = Variable(att_feats.data.new(batch_size, self.seq_length).zero_()) if utils.under_0_4() else att_feats.new_zeros(batch_size, self.seq_length)

        for t in range(self.seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.data.new(batch_size).long().zero_() if utils.under_0_4() else fc_feats.new_zeros(batch_size, dtype=torch.long)
            it = Variable(it, requires_grad=False) if utils.under_0_4() else it
            logprobs, state = self.get_logprobs_state(it, memory, att_masks, state)
            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size(0), self.vocab_size + 1)
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # sample the next word
            if t == self.seq_length:  # skip if we achieve maximum length
                break
            if sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data)  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature))
                it = torch.multinomial(prob_prev, 1)
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) if utils.under_0_4() else logprobs.gather(1, it)  # gather the logprobs at sampled positions
                it = it.view(-1).long()  # and flatten indices for downstream processing

            # stop when all finished
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)
            seq[:, t] = it
            seqLogprobs[:, t] = sampleLogprobs.view(-1)
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs
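The decoding_constraint option in Example #4 bans the previously emitted word by scattering -inf onto its position in the log-probability row. A toy check of that masking (values hypothetical):

# Toy check of the decoding-constraint masking used in Example #4.
import torch

logprobs = torch.tensor([[-0.5, -1.0, -2.0]])
prev_word = torch.tensor([[1]])            # previously emitted word per row
tmp = logprobs.new_zeros(logprobs.size())
tmp.scatter_(1, prev_word, float('-inf'))  # -inf at the banned position
print(logprobs + tmp)                      # tensor([[-0.5000, -inf, -2.0000]])

Example #5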
    def init_hidden(self, bsz):
        weight = next(self.parameters()).data if utils.under_0_4() else next(self.parameters())
        return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) if utils.under_0_4() else weight.new_zeros(self.num_layers, bsz, self.rnn_size),
                Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) if utils.under_0_4() else weight.new_zeros(self.num_layers, bsz, self.rnn_size))
Example #6
    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:  # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                xt = self.embed(Variable(it, requires_grad=False) if utils.under_0_4() else it)

            output, state = self.core(xt.unsqueeze(0), state)
            logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) if utils.under_0_4() else F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)

            # sample the next word
            if t == self.seq_length + 1:  # skip if we achieve maximum length
                break
            if sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False) if utils.under_0_4() else it)  # gather the logprobs at sampled positions
                it = it.view(-1).long()  # and flatten indices for downstream processing

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                it = it * unfinished.type_as(it)
                seq[:, t - 1] = it  # seq[t] the input of t+2 time step
                seqLogprobs[:, t - 1] = sampleLogprobs.view(-1)
                if unfinished.sum() == 0:
                    break

        return seq, seqLogprobs
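Examples #3, #4, and #6 share the same end-of-sequence bookkeeping: a running "unfinished" mask that drops to zero once a sequence emits token 0, after which that row is forced to keep emitting 0. A toy trace of one update step (hypothetical batch of three):

# Toy trace of the unfinished-mask update used in the samplers above.
import torch

it = torch.tensor([5, 0, 7])          # words just sampled; 0 marks <eos>/<pad>
unfinished = torch.tensor([1, 1, 0], dtype=torch.uint8)  # state after t-1
unfinished = unfinished * (it > 0).type_as(unfinished)   # row 1 just ended
it = it * unfinished.type_as(it)      # finished rows are forced back to 0
print(unfinished, it)                 # tensor([1, 0, 0]) tensor([5, 0, 0])

Example #7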
    def get_batch(self, split, batch_size=None):
        batch_size = batch_size or self.batch_size

        # pick an index of the datapoint to load next
        fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
        att_batch = np.ndarray((batch_size, 14*14, 2048), dtype = 'float32')
        max_index = self.N
        wrapped = False
        infos = []

        for i in range(batch_size):
            ri = self.iterator
            ri_next = ri + 1
            if ri_next >= max_index:
                ri_next = 0
                wrapped = True
                # wrap back around
            self.iterator = ri_next

            img = skimage.io.imread(self.files[ri])

            if len(img.shape) == 2:
                img = img[:,:,np.newaxis]
                img = np.concatenate((img, img, img), axis=2)

            img = img.astype('float32')/255.0
            img = torch.from_numpy(img.transpose([2,0,1])).cuda()
            img = Variable(preprocess(img), volatile=True) if utils.under_0_4() else preprocess(img)
            if utils.under_0_4():
                tmp_fc, tmp_att = self.my_resnet(img)
            else:
                with torch.no_grad(): tmp_fc, tmp_att = self.my_resnet(img)
            _tmp_att = tmp_att.contiguous().view(-1, 196, 2048)
            fc_batch[i] = tmp_fc.data.cpu().float().numpy()
            att_batch[i] = _tmp_att.data.cpu().float().numpy()

            info_struct = {}
            info_struct['id'] = self.ids[ri]
            info_struct['file_path'] = self.files[ri]
            infos.append(info_struct)

        data = {}
        data['fc_feats'] = fc_batch
        data['att_feats'] = att_batch
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i in range(len(att_batch)):
            data['att_masks'][i:(i+1), :att_batch[i].shape[0]] = 1
        data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
        data['infos'] = infos

        return data
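preprocess is used here (and again in Example #8 below) but never defined in these snippets; a plausible definition, assuming the usual ImageNet normalization for ResNet inputs, would be:

# Assumed definition of the preprocess transform (ImageNet channel statistics);
# the actual transform in the source repository may differ.
import torchvision.transforms as trn

preprocess = trn.Compose([
    trn.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
])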
Example #8
def main(params):
    net = getattr(resnet, params['model'])()
    net.load_state_dict(
        torch.load(os.path.join(params['model_root'],
                                params['model'] + '.pth')))
    my_resnet = myResnet(net)
    my_resnet.cuda()
    my_resnet.eval()

    imgs = json.load(open(params['input_json'], 'r'))
    imgs = imgs['images']
    N = len(imgs)

    seed(123)  # make reproducible

    dir_fc = params['output_dir'] + '_fc'
    dir_att = params['output_dir'] + '_att'
    if not os.path.isdir(dir_fc):
        os.mkdir(dir_fc)
    if not os.path.isdir(dir_att):
        os.mkdir(dir_att)

    for i, img in enumerate(imgs):
        # load the image
        I = skimage.io.imread(
            os.path.join(params['images_root'], img['filepath'],
                         img['filename']))
        # handle grayscale input images
        if len(I.shape) == 2:
            I = I[:, :, np.newaxis]
            I = np.concatenate((I, I, I), axis=2)

        I = I.astype('float32') / 255.0
        I = torch.from_numpy(I.transpose([2, 0, 1])).cuda()
        I = Variable(preprocess(I),
                     volatile=True) if utils.under_0_4() else preprocess(I)
        if utils.under_0_4():
            tmp_fc, tmp_att = my_resnet(I, params['att_size'])
        else:
            with torch.no_grad():
                tmp_fc, tmp_att = my_resnet(I, params['att_size'])
        # write features to disk (.npy for fc, compressed .npz for att)
        np.save(os.path.join(dir_fc, str(img['cocoid'])),
                tmp_fc.data.cpu().float().numpy())
        np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])),
                            feat=tmp_att.data.cpu().float().numpy())

        if i % 1000 == 0:
            print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N))
    print('wrote ', params['output_dir'])
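For reference, the features written above can be read back with plain NumPy; the paths and COCO id below are hypothetical:

# Hedged sketch: reading back the features saved by main() above.
import numpy as np

fc = np.load('output_fc/391895.npy')            # pooled feature, e.g. (2048,)
att = np.load('output_att/391895.npz')['feat']  # spatial grid of features

Example #9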
    def forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        outputs = []

        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size))
        p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] +
                                         (self.att_hid_size, )))

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                    #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(
                        it, requires_grad=False) if utils.under_0_4() else it
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if utils.under_0_4():
                if i >= 1 and seq[:, i].data.sum() == 0: break
            else:
                if i >= 1 and seq[:, i].sum() == 0: break

            xt = self.embed(it)

            output, state = self.core(xt, fc_feats, att_feats, p_att_feats,
                                      state)
            output = F.log_softmax(
                self.logit(output)) if utils.under_0_4() else F.log_softmax(
                    self.logit(output), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
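The scheduled-sampling block above flips a coin per batch row with probability ss_prob and, for the selected rows, replaces the ground-truth word with a word drawn from the model's previous output distribution. A minimal trace of the row selection (toy sizes):

# Toy trace of the scheduled-sampling row selection used above.
import torch

batch_size, ss_prob = 4, 0.25
sample_prob = torch.empty(batch_size).uniform_(0, 1)
sample_mask = sample_prob < ss_prob          # rows that feed back a sample
sample_ind = sample_mask.nonzero().view(-1)  # their indices
print(sample_prob, sample_ind)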
Example #10
    def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats,
                           tmp_p_att_feats, tmp_att_masks, state):
        # 'it' contains a word index
        xt = self.embed(
            Variable(it, requires_grad=False) if utils.under_0_4() else it)

        output, state = self.core(xt, tmp_fc_feats, tmp_att_feats,
                                  tmp_p_att_feats, state, tmp_att_masks)
        logprobs = torch.stack([
            F.softmax(m.logit(output[i]))
            if utils.under_0_4() else F.softmax(m.logit(output[i]), dim=1)
            for i, m in enumerate(self.models)
        ], 2).mean(2).log()

        return logprobs, state
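Note that the ensemble above averages the models' probabilities, not their log-probabilities, and only then takes the log. A quick numeric check with two toy distributions:

# Numeric check of probability-space ensembling as in get_logprobs_state.
import torch

p1 = torch.tensor([[0.7, 0.3]])
p2 = torch.tensor([[0.5, 0.5]])
logprobs = torch.stack([p1, p2], 2).mean(2).log()
print(logprobs.exp())  # tensor([[0.6000, 0.4000]]), the arithmetic mean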
Example #11
    def sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size))
        p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] +
                                         (self.att_hid_size, )))

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            tmp_fc_feats = fc_feats[k:k + 1].expand(beam_size,
                                                    self.fc_feat_size)
            tmp_att_feats = att_feats[k:k + 1].expand(
                *((beam_size, ) + att_feats.size()[1:])).contiguous()
            tmp_p_att_feats = p_att_feats[k:k + 1].expand(
                *((beam_size, ) + p_att_feats.size()[1:])).contiguous()

            for t in range(1):
                if t == 0:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(Variable(it, requires_grad=False) if utils.under_0_4() else it)

                output, state = self.core(xt, tmp_fc_feats, tmp_att_feats,
                                          tmp_p_att_feats, state)
                logprobs = F.log_softmax(self.logit(
                    output)) if utils.under_0_4() else F.log_softmax(
                        self.logit(output), dim=1)

            self.done_beams[k] = self.beam_search(state,
                                                  logprobs,
                                                  tmp_fc_feats,
                                                  tmp_att_feats,
                                                  tmp_p_att_feats,
                                                  opt=opt)
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
Example #12
    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            for t in range(2):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k + 1]).expand(
                        beam_size, self.input_encoding_size)
                elif t == 1:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(it)

                output, state = self.core(xt, state)
                logprobs = F.log_softmax(self.logit(output), dim=1)

            self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return (Variable(seq.transpose(
            0, 1)), Variable(seqLogprobs.transpose(
                0, 1))) if utils.under_0_4() else (seq.transpose(0, 1),
                                                   seqLogprobs.transpose(0, 1))
Example #13
    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0:  # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i - 1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i - 1].data.clone()
                        # prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data)  # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                        it = Variable(it, requires_grad=False) if utils.under_0_4() else it
                else:
                    it = seq[:, i - 1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i - 1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt.unsqueeze(0), state)
            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
Example #14
    def forward(self, xt, fc_feats, att_feats, state):
        att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
        att = att_feats.view(-1, self.att_feat_size)
        if self.att_hid_size > 0:
            att = self.ctx2att(att)  # (batch * att_size) * att_hid_size
            att = att.view(
                -1, att_size,
                self.att_hid_size)  # batch * att_size * att_hid_size
            att_h = self.h2att(state[0][-1])  # batch * att_hid_size
            att_h = att_h.unsqueeze(1).expand_as(
                att)  # batch * att_size * att_hid_size
            dot = att + att_h  # batch * att_size * att_hid_size
            dot = F.tanh(dot)  # batch * att_size * att_hid_size
            dot = dot.view(
                -1, self.att_hid_size)  # (batch * att_size) * att_hid_size
            dot = self.alpha_net(dot)  # (batch * att_size) * 1
            dot = dot.view(-1, att_size)  # batch * att_size
        else:
            att = self.ctx2att(att)  # (batch * att_size) * 1
            att = att.view(-1, att_size)  # batch * att_size
            att_h = self.h2att(state[0][-1])  # batch * 1
            att_h = att_h.expand_as(att)  # batch * att_size
            dot = att_h + att  # batch * att_size

        weight = F.softmax(dot) if utils.under_0_4() else F.softmax(dot, dim=1)
        att_feats_ = att_feats.view(
            -1, att_size,
            self.att_feat_size)  # batch * att_size * att_feat_size
        att_res = torch.bmm(weight.unsqueeze(1),
                            att_feats_).squeeze(1)  # batch * att_feat_size

        output, state = self.rnn(
            torch.cat([xt, att_res], 1).unsqueeze(0), state)
        return output.squeeze(0), state
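This core (like the attention in Example #24 below) implements Bahdanau-style additive attention: score_i = w^T tanh(W_a a_i + W_h h). A compact, self-contained restatement of the pattern, using batched projections instead of the flattened views above (dimensions hypothetical, not the repository's exact module):

# Minimal sketch of the additive-attention pattern used in this example.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, feat_size, hid_size, att_hid_size):
        super().__init__()
        self.ctx2att = nn.Linear(feat_size, att_hid_size)
        self.h2att = nn.Linear(hid_size, att_hid_size)
        self.alpha_net = nn.Linear(att_hid_size, 1)

    def forward(self, h, att_feats):                   # h: B x H, att_feats: B x N x F
        att = self.ctx2att(att_feats)                  # B x N x att_hid
        att_h = self.h2att(h).unsqueeze(1)             # B x 1 x att_hid
        dot = self.alpha_net(torch.tanh(att + att_h))  # B x N x 1
        weight = F.softmax(dot.squeeze(-1), dim=1)     # B x N
        return torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)  # B x F

Example #15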
        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: probabilities augmented after diversity
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing joint logprobs
            # OUTPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions
            # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    # compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c] if utils.under_0_4() else ys[
                        q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                    candidates.append({
                        'c': ix[q, c],
                        'q': q,
                        'p': candidate_logprob,
                        'r': local_unaug_logprob
                    })
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            # beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                # we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                # fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:,
                                                                        v['q']]
                # rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v[
                        'q']]  # dimension one is time step
                # append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v[
                    'p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates
Example #16
    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        logprobs = F.log_softmax(self.logit(output)) if utils.under_0_4() else F.log_softmax(self.logit(output), dim=1)

        return logprobs, state
Example #17
def get_self_critical_reward(model, fc_feats, attri_feats, att_feats, att_masks, data, gen_result, opt):
    batch_size = gen_result.size(0)  # batch_size = sample_size * seq_per_img
    seq_per_img = batch_size // len(data['gts'])
    
    # get greedy decoding baseline
    model.eval()
    greedy_res, _ = model(Variable(fc_feats.data, volatile=True) if utils.under_0_4() else fc_feats,
                          Variable(attri_feats.data, volatile=True) if utils.under_0_4() else attri_feats,
                          Variable(att_feats.data, volatile=True) if utils.under_0_4() else att_feats,
                          att_masks=Variable(att_masks.data, volatile=True) if utils.under_0_4() else att_masks, mode='sample')
    model.train()

    res = OrderedDict()
    
    gen_result = gen_result.data.cpu().numpy()
    greedy_res = greedy_res.data.cpu().numpy()
    for i in range(batch_size):
        res[i] = [array_to_str(gen_result[i])]
    for i in range(batch_size):
        res[batch_size + i] = [array_to_str(greedy_res[i])]

    gts = OrderedDict()
    for i in range(len(data['gts'])):
        gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))]

    res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)]
    res__ = {i: res[i] for i in range(2 * batch_size)}
    gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
    if opt.cider_reward_weight > 0:
        _, cider_scores = CiderD_scorer.compute_score(gts, res_)
        #print('Cider scores:', _)
    else:
        cider_scores = 0
    if opt.bleu_reward_weight > 0:
        _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
        bleu_scores = np.array(bleu_scores[3])
        #print('Bleu scores:', _[3])
    else:
        bleu_scores = 0
    scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores

    scores = scores[:batch_size] - scores[batch_size:]

    rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

    return rewards
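The returned reward is, for every sampled caption, its score minus the greedy baseline's score, repeated across all timesteps (the self-critical baseline). A toy numeric check with hypothetical scores:

# Toy check of the self-critical advantage computed above.
import numpy as np

scores = np.array([0.8, 0.5, 0.6, 0.4])   # [sampled scores | greedy scores]
batch_size, seq_len = 2, 3
adv = scores[:batch_size] - scores[batch_size:]      # array([0.2, 0.1])
rewards = np.repeat(adv[:, np.newaxis], seq_len, 1)  # shape (2, 3)
print(rewards)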
Example #18
    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            tmp_fc_feats = [
                fc_feats[i][k:k + 1].expand(beam_size, fc_feats[i].size(1))
                for i, m in enumerate(self.models)
            ]
            tmp_att_feats = [
                att_feats[i][k:k + 1].expand(
                    *((beam_size, ) + att_feats[i].size()[1:])).contiguous()
                for i, m in enumerate(self.models)
            ]
            tmp_p_att_feats = [
                p_att_feats[i][k:k + 1].expand(
                    *((beam_size, ) + p_att_feats[i].size()[1:])).contiguous()
                for i, m in enumerate(self.models)
            ]
            tmp_att_masks = [
                att_masks[i][k:k + 1].expand(
                    *((beam_size, ) + att_masks[i].size()[1:])).contiguous()
                for i, m in enumerate(self.models)
            ] if att_masks[0] is not None else att_masks

            it = fc_feats[0].data.new(beam_size).long().zero_()
            logprobs, state = self.get_logprobs_state(it, tmp_fc_feats,
                                                      tmp_att_feats,
                                                      tmp_p_att_feats,
                                                      tmp_att_masks, state)

            self.done_beams[k] = self.beam_search(state,
                                                  logprobs,
                                                  tmp_fc_feats,
                                                  tmp_att_feats,
                                                  tmp_p_att_feats,
                                                  tmp_att_masks,
                                                  opt=opt)
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        if utils.under_0_4():
            return Variable(seq.transpose(0, 1)), Variable(seqLogprobs.transpose(0, 1))
        else:
            return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
Example #19
def pack_wrapper(module, att_feats, att_masks):
    if att_masks is not None:
        if utils.under_0_4():
            packed = pack_padded_sequence(att_feats, list(att_masks.data.long().sum(1)), batch_first=True)
            return pad_packed_sequence(PackedSequence(module(packed[0]), packed[1]), batch_first=True)[0]
        else:
            packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
            return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)
Example #20
    def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state):
        # 'it' is a Variable containing a word index
        xt = self.embed(it)

        output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
        logprobs = F.log_softmax(self.logit(
            self.dropout(output))) if utils.under_0_4() else F.log_softmax(
                self.logit(self.dropout(output)), dim=1)

        return logprobs, state
Example #21
    def _forward(self, fc_feats, attri_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        outputs = Variable(fc_feats.data.new(batch_size, seq.size(1) - 1, self.vocab_size+1).zero_()) if utils.under_0_4() else fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size + 1)

        # Prepare the features
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) if utils.under_0_4() else fc_feats.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    # prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                    # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                    # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                    prob_prev = torch.exp(outputs[:, i-1].data if utils.under_0_4() else outputs[:, i - 1].detach())  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                    it = Variable(it, requires_grad=False) if utils.under_0_4() else it
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if utils.under_0_4():
                if i >= 1 and seq[:, i].data.sum() == 0: break
            else:
                if i >= 1 and seq[:, i].sum() == 0: break

            output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
            outputs[:, i] = output

        return outputs
Example #22
    def _prepare_feature(self, att_feats, att_masks=None, seq=None):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
        att_masks = att_masks.unsqueeze(-2)

        if seq is not None:
            # crop the last one
            seq = seq[:, :-1]
            seq_mask = (seq.data > 0)
            seq_mask[:, 0] += 1

            seq_mask = seq_mask.unsqueeze(-2)
            seq_mask = (seq_mask & subsequent_mask(seq.size(-1)).cuda()) if utils.under_0_4() else (seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask))
        else:
            seq_mask = None

        seq_mask = Variable(seq_mask, requires_grad=False) if utils.under_0_4() else seq_mask

        return att_feats, seq, att_masks, seq_mask
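subsequent_mask (also used in Examples #23 and #28) is not defined in these snippets; assuming it follows the well-known Annotated Transformer helper, it builds a triangular mask that hides future positions:

# Assumed definition of subsequent_mask, per the Annotated Transformer;
# the repository's version may differ in dtype or device handling.
import numpy as np
import torch

def subsequent_mask(size):
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0  # True where attending is allowed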
Example #23
    def _sample_(self, fc_feats, attri_feats, att_feats, att_masks=None, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        decoding_constraint = opt.get('decoding_constraint', 0)
        if beam_size > 1:
            return self._sample_beam(fc_feats, attri_feats, att_feats, att_masks, opt)

        if sample_max:
            with torch.no_grad():
                # greedy-decode with the regular sampler for the consistency checks below
                seq_, seqLogprobs_ = self._sample(fc_feats, attri_feats, att_feats, att_masks, opt)

        batch_size = att_feats.shape[0]

        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)

        memory = self.model.encode(att_feats, att_masks)
        ys = Variable(torch.zeros((batch_size, 1)).long().zero_()).cuda() if utils.under_0_4() else torch.zeros((batch_size, 1), dtype=torch.long).to(att_feats.device)

        seq = Variable(att_feats.data.new(batch_size, self.seq_length).long().zero_()) if utils.under_0_4() else att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
        seqLogprobs = Variable(att_feats.data.new(batch_size, self.seq_length).zero_()) if utils.under_0_4() else att_feats.new_zeros(batch_size, self.seq_length)

        for i in range(self.seq_length):
            out = self.model.decode(memory, att_masks, ys, subsequent_mask(ys.size(1)) if utils.under_0_4() else subsequent_mask(ys.size(1)).to(att_feats.device))
            logprob = self.model.generator(out[:, -1])
            if sample_max:
                sampleLogprobs, next_word = torch.max(logprob, dim=1)
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprob.data)  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprob.data, temperature))
                next_word = torch.multinomial(prob_prev, 1)
                sampleLogprobs = logprob.gather(1, next_word).view(-1)  # gather the logprobs at sampled positions
                next_word = next_word.view(-1)

            seq[:, i] = next_word
            seqLogprobs[:, i] = sampleLogprobs
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        assert (seq * ((seq_ > 0).long()) == seq_).all(), 'seq doesn\'t match'
        assert (seqLogprobs * ((seq_ > 0).float()) - seqLogprobs_ * ((seq_ > 0).float())).abs().max() < 1e-5, 'logprobs doesn\'t match'
        return seq, seqLogprobs
Example #24
    def forward(self, h, att_feats, p_att_feats, att_masks=None):
        # The p_att_feats here is already projected
        att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
        att = p_att_feats.view(-1, att_size, self.att_hid_size)

        att_h = self.h2att(h)  # batch * att_hid_size
        att_h = att_h.unsqueeze(1).expand_as(att)  # batch * att_size * att_hid_size
        dot = att + att_h  # batch * att_size * att_hid_size
        dot = F.tanh(dot)  # batch * att_size * att_hid_size
        dot = dot.view(-1, self.att_hid_size)  # (batch * att_size) * att_hid_size
        dot = self.alpha_net(dot)  # (batch * att_size) * 1
        dot = dot.view(-1, att_size)  # batch * att_size

        weight = F.softmax(dot) if utils.under_0_4() else F.softmax(dot, dim=1)  # batch * att_size
        if att_masks is not None:
            weight = weight * att_masks.view(-1, att_size).float()
            weight = weight / weight.sum(1, keepdim=True)  # normalize to 1
        att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1))  # batch * att_size * att_feat_size
        att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1)  # batch * att_feat_size

        return att_res
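The mask-and-renormalize step above zeroes the weights of padded positions and rescales the rest to sum to one. A toy check:

# Toy check of masked renormalization of attention weights.
import torch

weight = torch.tensor([[0.5, 0.3, 0.2]])
att_masks = torch.tensor([[1.0, 1.0, 0.0]])
weight = weight * att_masks
weight = weight / weight.sum(1, keepdim=True)
print(weight)  # tensor([[0.6250, 0.3750, 0.0000]])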
Example #25
def eval_split(opt, loader, i2t_model, nmt_model, eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)

    if opt.coco_eval_flag:
        return eval_split_coco_unpaired(opt, loader, i2t_model, nmt_model,
                                        eval_kwargs)

    # Make sure in the evaluation mode
    print('Start evaluating the model ...')
    if opt.i2t_eval_flag:
        i2t_crit = criterion.LanguageModelCriterion(opt)
        i2t_model.eval()

    if opt.nmt_eval_flag:
        nmt_crit = criterion.NMT_loss(opt,
                                      nmt_model.generator,
                                      criterion.NMTCriterion(
                                          loader.nmt_dicts['tgt'].size(), opt),
                                      eval=True)
        nmt_model.eval()

    loader.reset_iterator(split)
    beam_accum = {
        "predicted_ids": [],
        "beam_parent_ids": [],
        "scores": [],
        "log_probs": []
    }

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    if opt.i2t_eval_flag:
        while True:
            data = loader.get_batch(split)
            n = n + loader.batch_size

            if data.get('labels', None) is not None and verbose_loss:
                # forward the model to get loss
                tmp = [
                    data['fc_feats'], data['attri_feats'], data['att_feats'],
                    data['labels'], data['masks'], data['att_masks']
                ]
                tmp = [
                    _ if _ is None else
                    (Variable(torch.from_numpy(_), volatile=True).cuda()
                     if utils.under_0_4() else torch.from_numpy(_).cuda())
                    for _ in tmp
                ]
                fc_feats, attri_feats, att_feats, labels, masks, att_masks = tmp
                outputs = i2t_model(fc_feats, attri_feats, att_feats, labels,
                                    att_masks)
                loss = i2t_crit(outputs, labels[:, 1:], masks[:, 1:])
                loss = loss.data[0] if utils.under_0_4() else loss.item()
                loss_sum = loss_sum + loss
                loss_evals = loss_evals + 1

            # forward the model to also get generated samples for each image
            # Only leave one feature for each image, in case duplicate sample
            tmp = [
                data['fc_feats'][np.arange(loader.batch_size) *
                                 loader.seq_per_img],
                data['attri_feats'][np.arange(loader.batch_size) *
                                    loader.seq_per_img],
                data['att_feats'][np.arange(loader.batch_size) *
                                  loader.seq_per_img],
                data['att_masks'][np.arange(loader.batch_size) *
                                  loader.seq_per_img]
                if data['att_masks'] is not None else None
            ]
            tmp = [
                _ if _ is None else
                (Variable(torch.from_numpy(_), volatile=True).cuda()
                 if utils.under_0_4() else torch.from_numpy(_).cuda())
                for _ in tmp
            ]
            fc_feats, attri_feats, att_feats, att_masks = tmp
            # forward the model to also get generated samples for each image
            seq = i2t_model(fc_feats,
                            attri_feats,
                            att_feats,
                            att_masks,
                            opt=eval_kwargs,
                            mode='sample')[0].data
            #print(seq)
            # Print beam search
            if beam_size > 1 and verbose_beam:
                for i in range(loader.batch_size):
                    print('\n'.join([
                        utils.decode_sequence(loader.get_vocab(),
                                              _['seq'].unsqueeze(0))[0]
                        for _ in i2t_model.done_beams[i]
                    ]))
                    print('--' * 10)
            sents = utils.decode_sequence(loader.get_vocab(), seq)
            tgtBatch = []

            for k, sent in enumerate(sents):
                if verbose:
                    print('image %s: ' % (data['infos'][k]['id']),
                          sent.encode('utf8', 'replace'))
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                if eval_kwargs.get('dump_path', 0) == 1:
                    entry['file_name'] = data['infos'][k]['file_path']
                predictions.append(entry)
                if eval_kwargs.get('dump_images', 0) == 1:
                    # dump the raw image to vis/ folder
                    cmd = 'cp "' + os.path.join(
                        eval_kwargs['image_root'], data['infos'][k]
                        ['file_path']) + '" vis/imgs/img' + str(
                            len(predictions)) + '.jpg'  # bit gross
                    print(cmd)
                    os.system(cmd)

            # if we wrapped around the split or used up val imgs budget then bail
            ix0 = data['bounds']['it_pos_now']
            ix1 = data['bounds']['it_max']
            if num_images != -1:
                ix1 = min(ix1, num_images)
            for i in range(n - ix1):
                predictions.pop()

            if verbose:
                print('evaluating validation performance... %d/%d (%f)' %
                      (ix0 - 1, ix1, loss))

            if data['bounds']['wrapped']:
                break
            if num_images >= 0 and n >= num_images:
                break

        lang_stats = None
        if lang_eval == 1:
            if 'coco' in opt.input_json:
                lang_stats = language_eval('coco', predictions, opt.id, split)
            elif 'chinese' in opt.input_json:
                lang_stats = language_eval('zh', predictions, opt.id, split)
            elif '30k' in opt.input_json:
                lang_stats = language_eval('30k', predictions, opt.id, split)
            else:
                raise Exception('Current eval type is not recognizable.')
    # Switch back to training mode
    if opt.nmt_eval_flag:
        for i in tqdm(range(int(loader.nmt_validData.numBatches))):
            batch = loader.get_batch('val')
            outputs, attn, dec_hidden, _ = nmt_model(batch['nmt'].src,
                                                     batch['nmt'].tgt,
                                                     batch['nmt'].lengths)
            batch_loss = nmt_crit(loader, batch['nmt'], outputs, attn)

    if opt.nmt_train_flag: nmt_model.train()
    if opt.i2t_train_flag: i2t_model.train()

    if opt.i2t_eval_flag and opt.nmt_eval_flag:
        return loss_sum / loss_evals, predictions, lang_stats, nmt_crit.total_stats.ppl(
        ), nmt_crit.total_stats.accuracy()
    elif opt.nmt_eval_flag:
        return 0.0, None, None, nmt_crit.total_stats.ppl(
        ), nmt_crit.total_stats.accuracy()
    elif opt.i2t_eval_flag:
        return loss_sum / loss_evals, predictions, lang_stats, 0.0, 0.0
Example #26
 def forward(self, x):
     x = (x + Variable(self.pe[:, :x.size(1)], requires_grad=False)) if utils.under_0_4() else (x + self.pe[:, :x.size(1)])
     return self.dropout(x)
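The forward above adds a precomputed sinusoidal table self.pe to the input. Assuming the standard construction from "Attention Is All You Need" (d_model and max_len below are hypothetical defaults), that buffer is built as:

# Assumed construction of the positional-encoding buffer self.pe used above.
import math
import torch

def build_pe(d_model=512, max_len=5000):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                         -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)  # (1, max_len, d_model), consumed as x + pe[:, :T]

Example #27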
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda,
                          bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][
                            prev_decisions[prev_labels]] = logprobsf[sub_beam][
                                prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: probabilities augmented after diversity
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing joint logprobs
            # OUTPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions
            # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    # compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c] if utils.under_0_4() else ys[
                        q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                    candidates.append({
                        'c': ix[q, c],
                        'q': q,
                        'p': candidate_logprob,
                        'r': local_unaug_logprob
                    })
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            # beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                # we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                # fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:,
                                                                        v['q']]
                # rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v[
                        'q']]  # dimension one is time step
                # append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v[
                    'p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        max_ppl = opt.get('max_ppl', 0)
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [
            torch.LongTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_seq_logprobs_table = [
            torch.FloatTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_logprobs_sum_table = [
            torch.zeros(bdash) for _ in range(group_size)
        ]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        state_table = [
            list(torch.unbind(_))
            for _ in torch.stack(init_state).chunk(group_size, 2)
        ]
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = [
            _.chunk(group_size) if _ is not None else [None] * group_size
            for _ in args
        ]
        args = [[args[i][j] for i in range(len(args))]
                for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].data.float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(
                            1, beam_seq_table[divm][t - divm -
                                                    1].unsqueeze(1).cuda(),
                            float('-inf'))
                    # suppress UNK tokens in the decoding
                    logprobsf[:, logprobsf.size(1) -
                              1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf,
                                                    t, divm, diversity_lambda,
                                                    bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm], \
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t - divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up, or the end token is reached, copy the beam
                    for vix in range(bdash):
                        if beam_seq_table[divm][t - divm, vix] == 0 \
                                or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': (beam_seq_logprobs_table[divm][:, vix].sum()
                                            if utils.under_0_4() else
                                            beam_seq_logprobs_table[divm][:, vix].sum().item()),
                                'p': (beam_logprobs_sum_table[divm][vix]
                                      if utils.under_0_4() else
                                      beam_logprobs_sum_table[divm][vix].item()),
                            }
                            if max_ppl:
                                # length-normalize the beam score
                                final_beam['p'] = final_beam['p'] / (t - divm + 1)
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time
                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                        Variable(it.cuda()) if utils.under_0_4() else it.cuda(),
                        *(args[divm] + [state_table[divm]]))

        # all beams are sorted by their log-probabilities
        done_beams_table = [
            sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
            for i in range(group_size)
        ]
        done_beams = reduce(lambda a, b: a + b, done_beams_table)  # functools.reduce on Python 3
        return done_beams
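The diversity penalty referenced above is applied by add_diversity, which is not part of this snippet. Below is a minimal sketch of a Hamming-diversity penalty with the same call signature, following the standard diverse beam search formulation; the helper actually used by this codebase may differ in its details.

import torch

def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
    # Sketch: for every beam in group divm, subtract diversity_lambda from
    # the log-probability of each word already chosen at this local time
    # step by the earlier groups, and keep an unpenalized copy for ranking.
    local_time = t - divm
    unaug_logprobsf = logprobsf.clone()
    for prev_group in range(divm):
        prev_decisions = beam_seq_table[prev_group][local_time]
        for sub_beam in range(bdash):
            for prev_label in range(bdash):
                logprobsf[sub_beam][prev_decisions[prev_label]] -= diversity_lambda
    return unaug_logprobsf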
Example No. 28
    def get_logprobs_state(self, it, memory, mask, state):
        """
        state = [ys.unsqueeze(0)]
        """
        if state is None:
            ys = it.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        out = self.model.decode(
            memory, mask, ys,
            Variable(subsequent_mask(ys.size(1)), requires_grad=False).cuda()
            if utils.under_0_4() else subsequent_mask(ys.size(1)).to(memory.device))
        logprobs = self.model.generator(out[:, -1])

        return logprobs, [ys.unsqueeze(0)]
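get_logprobs_state depends on subsequent_mask to keep decoding causal, but the helper is not shown in this snippet. A common definition, following the Annotated Transformer, is sketched below.

import numpy as np
import torch

def subsequent_mask(size):
    # Mask of shape (1, size, size): position i may attend only to
    # positions j <= i; the strict upper triangle is masked out.
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0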
Example No. 29
    def train(self, data, loader, iteration, epoch, nmt_epoch):
        nmt_dec_state = None
        nmt_dec_state_zh = None
        torch.cuda.synchronize()
        self.optim.zero_grad()

        tmp = [
            data['fc_feats'], data['attri_feats'], data['att_feats'],
            data['labels'], data['masks'], data['att_masks'],
            data['nmt'] if self.nmt_train_flag else None
        ]
        tmp = [
            _ if _ is None else
            (Variable(torch.from_numpy(_), requires_grad=False).cuda()
             if utils.under_0_4() else torch.from_numpy(_).cuda()) for _ in tmp
        ]
        fc_feats, attri_feats, att_feats, labels, masks, att_masks, nmt_batch = tmp

        if self.i2t_train_flag:
            if self.update_i2t_lr_flag:
                self.optim.update_LearningRate('i2t', epoch)  # Assign the learning rate
                self.optim.update_ScheduledSampling_prob(
                    self.opt, epoch, self.dp_i2t_model)  # Assign the scheduled sampling prob
                if self.opt.self_critical_after != -1 and epoch >= self.opt.self_critical_after:
                    # start self-critical training from this epoch on
                    self.sc_flag = True
                    init_scorer(self.opt.cached_tokens)
                else:
                    self.sc_flag = False
                self.update_i2t_lr_flag = False

            if not self.sc_flag:
                i2t_outputs = self.dp_i2t_model(fc_feats, attri_feats,
                                                att_feats, labels, att_masks)
                i2t_loss = self.i2t_crit(i2t_outputs, labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = self.dp_i2t_model(
                    fc_feats,
                    attri_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    mode='sample')
                reward = get_self_critical_reward(self.dp_i2t_model, fc_feats,
                                                  attri_feats, att_feats,
                                                  att_masks, data, gen_result,
                                                  self.opt)
                i2t_loss = self.i2t_rl_crit(
                    sample_logprobs, gen_result.data,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))

                self.i2t_avg_reward = np.mean(reward[:, 0])
            self.i2t_train_loss = i2t_loss.data[0] if utils.under_0_4() else i2t_loss.item()
            i2t_loss.backward(retain_graph=True)

        if self.nmt_train_flag:
            if self.update_nmt_lr_flag:
                self.optim.update_LearningRate('nmt', nmt_epoch)  # Assign the learning rate
            outputs, attn, dec_state, upper_bounds = self.dp_nmt_model(
                nmt_batch.src, nmt_batch.tgt, nmt_batch.lengths, nmt_dec_state)
            nmt_loss = self.nmt_crit(loader, nmt_batch, outputs, attn)

            if nmt_dec_state is not None:
                nmt_dec_state.detach()
            if nmt_dec_state_zh is not None:
                nmt_dec_state_zh.detach()

            self.nmt_crit.report_stats.n_src_words += nmt_batch.lengths.data.sum()
            self.nmt_train_ppl = self.nmt_crit.report_stats.ppl()
            self.nmt_train_acc = self.nmt_crit.report_stats.accuracy()
            # (Disabled) minimize the distance between the i2t and NMT word-embedding weights:
            # wemb_weight_loss = self.weight_trans(self.i2t_model.embed, self.nmt_encoder.embeddings.word_lut)
            # self.wemb_loss = wemb_weight_loss.data[0]

            nmt_loss.backward(retain_graph=True)
        # if self.nmt_train_flag: wemb_weight_loss.backward(retain_graph=True)
        self.optim.step()
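In the self-critical branch above, i2t_rl_crit weights the log-probabilities of the sampled words by the baseline-subtracted reward. Its definition is outside this snippet; the sketch below shows what such a REINFORCE-style criterion typically computes (the class and variable names are illustrative, not taken from this codebase).

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    # Sketch of a REINFORCE loss: -(logprob of each sampled word) * reward,
    # averaged over the valid (non-padding) positions of the samples.
    def forward(self, logprobs, seq, reward):
        logprobs = logprobs.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        # count words up to and including the first end token (index 0)
        mask = (seq > 0).float()
        mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
        mask = mask.contiguous().view(-1)
        loss = -logprobs * reward * mask
        return loss.sum() / mask.sum()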