Code example #1
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # applies the diversity penalty in place: subtracts diversity_lambda from words
        # already chosen by previous groups and returns an unaugmented copy of the log-probs
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda,
                          bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][
                            prev_decisions[prev_labels]] = logprobsf[sub_beam][
                                prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                    candidates.append({
                        'c': ix[q, c],
                        'q': q,
                        'p': candidate_logprob,
                        'r': local_unaug_logprob
                    })
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:,
                                                                        v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v[
                        'q']]  # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v[
                    'p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        max_ppl = opt.get('max_ppl', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [
            torch.LongTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_seq_logprobs_table = [
            torch.FloatTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_logprobs_sum_table = [
            torch.zeros(bdash) for _ in range(group_size)
        ]

        # logprobs: log-probs predicted in the last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        state_table = [
            list(torch.unbind(_))
            for _ in torch.stack(init_state).chunk(group_size, 2)
        ]
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = [
            _.chunk(group_size) if _ is not None else [None] * group_size
            for _ in args
        ]
        args = [[args[i][j] for i in range(len(args))]
                for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].data.float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(
                            1, beam_seq_table[divm][t - divm -
                                                    1].unsqueeze(1).cuda(),
                            float('-inf'))
                    # suppress UNK tokens in the decoding
                    logprobsf[:, logprobsf.size(1) -
                              1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf,
                                                    t, divm, diversity_lambda,
                                                    bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][
                                t - divm,
                                vix] == 0 or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq':
                                beam_seq_table[divm][:, vix].clone(),
                                'logps':
                                beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p':
                                beam_seq_logprobs_table[divm]
                                [:, vix].sum().item(),
                                'p':
                                beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(
                                t - divm + 1, final_beam['p'])
                            # if max_ppl:
                            #     final_beam['p'] = final_beam['p'] / (t-divm+1)
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[
                        divm] = self.get_logprobs_state(
                            it.cuda(), *(args[divm] + [state_table[divm]]))

        # all beams are sorted by their log-probabilities
        done_beams_table = [
            sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
            for i in range(group_size)
        ]
        done_beams = reduce(lambda a, b: a + b, done_beams_table)
        return done_beams
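
A standalone sketch of the diversity penalty applied by add_diversity above: words already chosen by earlier groups at the same local time step have their log-probs reduced by diversity_lambda. All sizes, values, and the name prev_group_decisions are made up for illustration; they are not taken from the listing.

import torch

bdash, V, diversity_lambda = 3, 10, 0.5
logprobsf = torch.zeros(bdash, V)
prev_group_decisions = [torch.tensor([2, 5, 2])]   # one earlier group, bdash choices

unaug_logprobsf = logprobsf.clone()                # kept for ranking candidates later
for prev_decisions in prev_group_decisions:
    for sub_beam in range(bdash):
        for prev_label in range(bdash):
            logprobsf[sub_beam, prev_decisions[prev_label]] -= diversity_lambda

# word 2 was chosen twice and word 5 once, so their scores drop by 1.0 and 0.5
print(logprobsf[0])        # tensor([ 0., 0., -1., 0., 0., -0.5, 0., ...])
print(unaug_logprobsf[0])  # untouched copy
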
Code example #2
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # applies the diversity penalty: subtracts diversity_lambda for words already chosen
        # by previous groups and returns both the penalized and the unaugmented log-probs
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda,
                          bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[
                        prev_choice][:, :, local_time]  # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(
                            1, prev_decisions[:, prev_labels].unsqueeze(-1),
                            change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    logprobs = logprobs - self.repeat_tensor(
                        bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs

        # does one step of classical beam search

        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobs: probabilities augmented after diversity N*bxV
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
            #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            #beam_logprobs_sum : joint log-probability of each beam Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
            if t == 0:
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(
                -1) + logprobs  # beam_logprobs_sum Nxb logprobs is NxbxV
            ys, ix = torch.sort(
                candidate_logprobs.reshape(candidate_logprobs.shape[0], -1),
                -1, True)
            ys, ix = ys[:, :beam_size], ix[:, :beam_size]
            beam_ix = ix // vocab_size  # Nxb which beam
            selected_ix = ix % vocab_size  # Nxb # which word
            state_ix = (
                beam_ix +
                torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) *
                logprobs.shape[1]).reshape(-1)  # N*b which in Nxb beams

            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(
                    1,
                    beam_ix.unsqueeze(-1).
                    expand_as(beam_seq)) == beam_seq.reshape(
                        -1,
                        beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(
                    1,
                    beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(
                    1,
                    beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
                        beam_seq_logprobs))

            beam_seq = torch.cat(
                [beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(
                batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(
                batch_size, -1, vocab_size).gather(
                    1,
                    beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size))  # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)
            ], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get(
            'temperature',
            1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        beam_seq_table = [
            torch.LongTensor(batch_size, bdash, 0).to(device)
            for _ in range(group_size)
        ]
        beam_seq_logprobs_table = [
            torch.FloatTensor(batch_size, bdash, 0,
                              self.vocab_size + 1).to(device)
            for _ in range(group_size)
        ]
        beam_logprobs_sum_table = [
            torch.zeros(batch_size, bdash).to(device)
            for _ in range(group_size)
        ]

        # logprobs: log-probs predicted in the last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)]
                            for _ in range(batch_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
        state_table = [[_.clone() for _ in init_state]
                       for _ in range(group_size)]
        # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = utils.split_tensors(
            group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))]
                     for j in range(len(args))] for k in range(group_size)
                    ]  # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))]
                    for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobs.scatter_(
                            1, beam_seq_table[divm][:, :,
                                                    t - divm - 1].reshape(
                                                        -1, 1).to(device),
                            float('-inf'))
                    if remove_bad_endings and t - divm > 0:
                        logprobs[torch.from_numpy(
                            np.isin(
                                beam_seq_table[divm][:, :, t - divm - 1].cpu().
                                numpy(), self.bad_endings_ix)).reshape(-1),
                                 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(
                            self,
                            'vocab') and self.vocab[str(logprobs.size(1) -
                                                        1)] == 'UNK':
                        logprobs[:, logprobs.size(1) -
                                 1] = logprobs[:, logprobs.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobs values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    logprobs, unaug_logprobs = add_diversity(
                        beam_seq_table, logprobs, t, divm, diversity_lambda,
                        bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm] = beam_step(logprobs,
                                                unaug_logprobs,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t - divm] == 0
                        assert beam_seq_table[divm].shape[-1] == t - divm + 1
                        if t == self.seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq':
                                    beam_seq_table[divm][b, vix].clone(),
                                    'logps':
                                    beam_seq_logprobs_table[divm][b,
                                                                  vix].clone(),
                                    'unaug_p':
                                    beam_seq_logprobs_table[divm][
                                        b, vix].sum().item(),
                                    'p':
                                    beam_logprobs_sum_table[divm][b,
                                                                  vix].item()
                                }
                                final_beam['p'] = length_penalty(
                                    t - divm + 1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                    logprobs_table[divm], state_table[
                        divm] = self.get_logprobs_state(
                            it.cuda(), *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] /
                                                         temperature,
                                                         dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[
            sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash]
            for i in range(group_size)
        ] for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams
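
Example #2 replaces the per-candidate Python loop of example #1 with a batched top-k over flattened (beam, word) scores. A small self-contained sketch of that indexing trick follows; the shapes N, b, V are invented for illustration, not the real batch/beam/vocabulary sizes.

import torch

# Add the running beam scores to the next-word log-probs, flatten the
# (beam, word) axes, keep the best b joint scores per batch row, then
# recover which beam and which word each winner corresponds to.
N, b, V = 2, 3, 7
beam_logprobs_sum = torch.randn(N, b)          # running score per beam
logprobs = torch.randn(N, b, V)                # next-word log-probs per beam

candidate = beam_logprobs_sum.unsqueeze(-1) + logprobs             # N x b x V
ys, ix = torch.sort(candidate.reshape(N, -1), -1, descending=True)
ys, ix = ys[:, :b], ix[:, :b]                  # keep the top b continuations
beam_ix = ix // V                              # which beam each one extends
word_ix = ix % V                               # which word it appends
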
Code example #3
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        opt = kwargs['opt']
        beam_size = opt.get('beam_size', 10)
        max_seqtree_length = opt.get('max_seqtree_length', 40)
        temperature = opt.get('temperature', 1)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        suppress_EOB_factor = opt.get('suppress_EOB_factor', 1)
        # assert suppress_EOB_factor > 1

        batch_size = init_logprobs.size(0)
        device = init_logprobs.device

        beam_seq_table = torch.LongTensor(batch_size, beam_size, 0).to(device)
        beam_parent_idx_table = torch.LongTensor(batch_size, beam_size,
                                                 max_seqtree_length).to(device)
        beam_parent_idx_table.fill_(0)
        beam_hidden_states_table = torch.FloatTensor(batch_size * beam_size,
                                                     max_seqtree_length,
                                                     self.rnn_size).to(device)
        beam_cell_states_table = torch.FloatTensor(batch_size * beam_size,
                                                   max_seqtree_length,
                                                   self.rnn_size).to(device)

        # init state
        # init_state `(batch_size, rnn_size)` -> `(batch_size*beam_size, rnn_size)`
        beam_hidden_states_table[:,
                                 0, :] = init_state[0].unsqueeze(dim=1).repeat(
                                     1, beam_size,
                                     1).view(batch_size * beam_size, -1)
        beam_cell_states_table[:,
                               0, :] = init_state[1].unsqueeze(dim=1).repeat(
                                   1, beam_size,
                                   1).view(batch_size * beam_size, -1)

        beam_seq_logprobs_table = torch.FloatTensor(
            batch_size, beam_size, 0, self.vocab_size + 1).to(device)
        beam_logprobs_sum_table = torch.zeros(batch_size, beam_size).to(device)
        logprobs = init_logprobs

        # generation finished utils
        counter_table = torch.LongTensor(batch_size, beam_size).to(device)
        counter_table.fill_(1)
        seqLen_table = torch.LongTensor(batch_size, beam_size).to(device)
        seqLen_table.fill_(0)
        all_finished_table = torch.BoolTensor(batch_size, beam_size).to(device)
        all_finished_table.fill_(0)

        done_beams_table = [[] for _ in range(batch_size)]

        for i in range(1, max_seqtree_length):
            if suppress_EOB_factor > 1:
                logprobs[:, self.
                         vocab_size] = logprobs[:, self.
                                                vocab_size] * suppress_EOB_factor
            logprobs[:, 0] = logprobs[:, 0] - 1000

            beam_seq_table, \
            beam_parent_idx_table, \
            beam_seq_logprobs_table, \
            beam_logprobs_sum_table, \
            (beam_hidden_states_table, \
            beam_cell_states_table), \
            counter_table, \
            seqLen_table, \
            all_finished_table = self.beam_step(logprobs,
                                            beam_size,
                                            i-1,
                                            beam_seq_table,
                                            beam_parent_idx_table,
                                            beam_seq_logprobs_table,
                                            beam_logprobs_sum_table,
                                            (beam_hidden_states_table,
                                            beam_cell_states_table),
                                            counter_table,
                                            seqLen_table,
                                            all_finished_table)

            for b in range(batch_size):
                is_end = all_finished_table[b, :]
                if i == max_seqtree_length - 1:
                    is_end.fill_(1)
                for vix in range(beam_size):
                    if is_end[vix]:
                        final_beam = {
                            'seq': beam_seq_table[b, vix].clone(),
                            'seq_idx': beam_parent_idx_table[b, vix].clone(),
                            'seqLen': seqLen_table[b, vix].clone(),
                            'logps': beam_seq_logprobs_table[b, vix].clone(),
                            'unaug_p':
                            beam_seq_logprobs_table[b, vix].sum().item(),
                            'p': beam_logprobs_sum_table[b, vix].item(),
                            'counter': counter_table[b, vix].item()
                        }
                        final_beam['p'] = length_penalty(
                            (final_beam['seq'] !=
                             self.vocab_size).sum().item(), final_beam['p'])
                        # print(final_beam['seq'].size(), final_beam['seqLen'])
                        done_beams_table[b].append(final_beam)
                beam_logprobs_sum_table[b, is_end] -= 1000

            # move the current group one step forward in time
            seqtree = beam_seq_table.view(batch_size * beam_size, -1)
            parent_idx = beam_parent_idx_table.view(batch_size * beam_size,
                                                    max_seqtree_length)
            p_it = torch.gather(seqtree,
                                dim=1,
                                index=parent_idx[:, i].clone().unsqueeze(1))
            p_it = p_it.squeeze(dim=1)
            p_xt = self.embed(p_it)

            hidden_states = beam_hidden_states_table.view(
                batch_size * beam_size, max_seqtree_length, self.rnn_size)
            cell_states = beam_cell_states_table.view(batch_size * beam_size,
                                                      max_seqtree_length,
                                                      self.rnn_size)

            p_idx = parent_idx[:, i].clone()
            p_idx = p_idx.unsqueeze(1).unsqueeze(1).expand(
                batch_size * beam_size, 1, self.hidden_size)
            p_hidden_state = torch.gather(hidden_states, dim=1,
                                          index=p_idx).squeeze(dim=1)
            p_cell_state = torch.gather(cell_states, dim=1,
                                        index=p_idx).squeeze(dim=1)

            p_state = p_hidden_state, p_cell_state

            if i % 3 == 1:
                s_xt = self.init_input(batch_size * beam_size)
                s_state = self.init_hidden(batch_size * beam_size)
            else:
                s_it = seqtree[:, i - 1].clone()
                s_xt = self.embed(s_it)
                s_hidden_state = hidden_states[:, i - 1].clone()
                s_cell_state = cell_states[:, i - 1].clone()
                s_state = s_hidden_state, s_cell_state

            logprobs, _state = self.get_logprobs_state(p_xt, s_xt, p_state,
                                                       s_state, *args)
            # logprobs = logprobs.view(batch_size, beam_size, self.vocab_size+1)
            logprobs = F.log_softmax(logprobs, dim=-1)
            # beam_hidden_states_table[:,:,i,:] = state[0].view(-1, self.rnn_size)
            # beam_cell_states_table[:,:,i,:] = state[1].view(-1, self.rnn_size)
            beam_hidden_states_table[:, i, :] = _state[0]
            beam_cell_states_table[:, i, :] = _state[1]

        # all beams are sorted by their log-probabilities
        done_beams_table = [
            sorted(done_beams_table[b], key=lambda x: -x['p'])
            for b in range(batch_size)
        ]
        # done_beams_table = [sorted(done_beams_table[b], key=lambda x: -x['p'])[:beam_size] for b in range(batch_size)]
        # done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams_table
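
Example #3 grows a tree: every step stores its hidden and cell state, and the next step gathers the state of its parent node with torch.gather. A tiny standalone illustration of that lookup pattern is below; the sizes are arbitrary and unrelated to the real rnn_size.

import torch

# For each row, pick the stored state at the position given by parent_idx.
B, L, H = 4, 6, 5                        # rows, stored steps, state width
hidden_states = torch.randn(B, L, H)     # one state saved per generated node
parent_idx = torch.randint(0, L, (B,))   # parent position for the next node

# expand the index so gather selects a full H-dimensional state per row
idx = parent_idx.view(B, 1, 1).expand(B, 1, H)
parent_state = torch.gather(hidden_states, dim=1, index=idx).squeeze(1)  # B x H
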
Code example #4
File: denseCaptionModel.py  Project: cici-ai-club/3M
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # applies the diversity penalty in place: subtracts diversity_lambda from words
        # already chosen by previous groups and returns an unaugmented copy of the log-probs
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda,
                          bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][
                            prev_decisions[prev_labels]] = logprobsf[sub_beam][
                                prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam
            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                    candidates.append({
                        'c': ix[q, c],
                        'q': q,
                        'p': candidate_logprob,
                        'r': local_unaug_logprob
                    })
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:,
                                                                        v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v[
                        'q']]  # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v[
                    'p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get(
            'temperature',
            1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 1)
        remove_bad_endings = opt.get('remove_bad_endings', 1)
        block_trigrams = opt.get('block_trigrams', 1)
        opt['length_penalty'] = 'avg_0'
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [
            torch.LongTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_seq_logprobs_table = [
            torch.FloatTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_logprobs_sum_table = [
            torch.zeros(bdash) for _ in range(group_size)
        ]

        # logprobs: log-probs predicted in the last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT
        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[
                _.chunk(group_size) if _ is not None else [None] * group_size
                for _ in args_
            ] for args_ in args]  # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))]
                     for j in range(len(args))] for k in range(group_size)
                    ]  # group_name, arg_name, model_name
        else:
            args = [
                _.chunk(group_size) if _ is not None else [None] * group_size
                for _ in args
            ]
            args = [[args[i][j] for i in range(len(args))]
                    for j in range(group_size)]
        trigrams = []
        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].data.float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(
                            1, beam_seq_table[divm][t - divm -
                                                    1].unsqueeze(1).cuda(),
                            -10e20)
                        #if t-divm>=2:
                        #logprobsf.scatter_(1, beam_seq_table[divm][t-divm-2].unsqueeze(1).cuda(), -10e20)
                    if remove_bad_endings and t - divm > 0:
                        logprobsf[torch.from_numpy(
                            np.isin(
                                beam_seq_table[divm][t - divm - 1].cpu().numpy(
                                ), self.bad_endings_ix).astype(bool)),
                                  0] = -10e20
                    # suppress UNK tokens in the decoding
                    logprobsf[:, logprobsf.size(1) -
                              1] = logprobsf[:, logprobsf.size(1) - 1] - 10e20
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    if block_trigrams and t - divm >= 3:
                        # Store trigram generated at last step
                        prev_two_batch = beam_seq_table[divm][
                            t - 3:t - 1]  #time*beam_size
                        for i in range(bdash):  # = seq.size(0)
                            prev_two = (prev_two_batch[0][i].item(),
                                        prev_two_batch[1][i].item())
                            current = beam_seq_table[divm][t - 1][i]
                            if t == 3:  # initialize
                                trigrams.append({
                                    prev_two: [current]
                                })  # {LongTensor: list containing 1 int}
                            elif t > 3:
                                if prev_two in trigrams[i]:  # add to list
                                    trigrams[i][prev_two].append(current)
                                else:  # create list
                                    trigrams[i][prev_two] = [current]
                            # Block used trigrams at next step
                        prev_two_batch = beam_seq_table[divm][t - 2:t]
                        mask = torch.zeros(logprobsf.size(),
                                           requires_grad=False).cuda(
                                           )  # batch_size x vocab_size
                        for i in range(bdash):
                            prev_two = (prev_two_batch[0][i].item(),
                                        prev_two_batch[1][i].item())
                            if prev_two in trigrams[i]:
                                for j in trigrams[i][prev_two]:
                                    mask[i, j] += 1

                        alpha = 10e20  # = 4
                        logprobsf = logprobsf + (
                            mask * -0.693 * alpha
                        )  # ln(1/2) * alpha (alpha -> infty works best)

                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf,
                                                    t, divm, diversity_lambda,
                                                    bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][
                                t - divm,
                                vix] == 0 or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq':
                                beam_seq_table[divm][:, vix].clone(),
                                'logps':
                                beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p':
                                beam_seq_logprobs_table[divm]
                                [:, vix].sum().item(),
                                'p':
                                beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(
                                t - divm + 1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[
                        divm] = self.get_logprobs_state(
                            it.cuda(), *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] /
                                                         temperature,
                                                         dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [
            sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
            for i in range(group_size)
        ]
        done_beams = functools.reduce(lambda a, b: a + b, done_beams_table)
        return done_beams
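
The block_trigrams branch in example #4 can be read as the following standalone sketch: remember every trigram a beam has produced, and when the last two words of a beam match a stored pair, push the log-probs of the words that completed that trigram toward minus infinity. The concrete values below are fabricated for illustration.

import torch

bdash, V = 3, 11
logprobsf = torch.randn(bdash, V)           # per-beam next-word log-probs

# trigrams[i] maps a (w_{t-2}, w_{t-1}) pair to the words already used after it on beam i
trigrams = [dict() for _ in range(bdash)]
trigrams[0][(4, 7)] = [2, 9]                # pretend beam 0 already produced "4 7 2" and "4 7 9"
prev_two_batch = [(4, 7), (1, 3), (4, 7)]   # last two words of each beam

mask = torch.zeros_like(logprobsf)
for i, prev_two in enumerate(prev_two_batch):
    for w in trigrams[i].get(prev_two, []):
        mask[i, w] += 1

alpha = 10e20                                   # "alpha -> infinity" as in the listing
logprobsf = logprobsf + mask * -0.693 * alpha   # ~ln(1/2) * alpha per blocked word
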
Code example #5
    def beam_search(self, init_state, init_logprobs, init_aleatorics, init_epistemics, *args, **kwargs):

        # applies the diversity penalty in place: subtracts diversity_lambda from words
        # already chosen by previous groups and returns an unaugmented copy of the log-probs
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search
        def beam_step(logprobsf, unaug_logprobsf, aleatorics, epistemics, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_seq_al, beam_seq_ep, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity (beam_size, vocab_size)
            #aleatorics: aleatoric uncertainties evaluated at current step (beam_size,)
            #epistemics: epistemic uncertainties evaluated at current step (beam_size,)            
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #beam_seq_al: tensor containing the aleatoric uncertainties of the candidates
            #beam_seq_ep: tensor containing the epistemic uncertainties of the candidates
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols): # for each column (word, essentially)
                for q in range(rows): # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]})
            candidates = sorted(candidates,  key=lambda x: -x['p'])
            
            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
            #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
                beam_seq_al_prev = beam_seq_al[:t].clone()
                beam_seq_ep_prev = beam_seq_ep[:t].clone()                
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                    beam_seq_al[:t, vix] = beam_seq_al_prev[:, v['q']]
                    beam_seq_ep[:t, vix] = beam_seq_ep_prev[:, v['q']]                    
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c'] # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
                beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
                beam_seq_al[t, vix] = aleatorics[v['q']]
                beam_seq_ep[t, vix] = epistemics[v['q']]                
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_seq_al, beam_seq_ep, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        uncertainty_lambda = opt.get('uncertainty_lambda', 0)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size # beam per group

        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
        beam_seq_aleatorics_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_epistemics_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]        
        done_beams_table = [[] for _ in range(group_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))  
        # logprobs: log-probs predicted in the last time step, shape (beam_size, vocab_size+1)
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # [(beam_size,)]
        aleatorics_table = list(init_aleatorics.chunk(group_size, 0))
        epistemics_table = list(init_epistemics.chunk(group_size, 0))        
        # END INIT

        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
        else:
            args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size): 
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].float()
                    # suppress previous word
                    if decoding_constraint and t-divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf'))
                    if remove_bad_endings and t-divm > 0:
                        logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
                        logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000  
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash)
                    
                    # get current uncertainties
                    aleatorics = aleatorics_table[divm]
                    epistemics = epistemics_table[divm]

                    # add uncertainty
                    logprobsf = logprobsf - uncertainty_lambda * epistemics.unsqueeze(-1)
                    
                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    beam_seq_aleatorics_table[divm],\
                    beam_seq_epistemics_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                aleatorics,
                                                epistemics,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                beam_seq_aleatorics_table[divm],
                                                beam_seq_epistemics_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'aleatorics': beam_seq_aleatorics_table[divm][:, vix].clone(),
                                'epistemics': beam_seq_epistemics_table[divm][:, vix].clone(),                                
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time
                    it = beam_seq_table[divm][t-divm]
                    logprobs_table[divm], state_table[divm], aleatorics_table[divm], epistemics_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = sum(done_beams_table, [])
        return done_beams
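
The main scoring difference from the earlier listings is the epistemic-uncertainty penalty subtracted just before beam_step. In isolation that single step looks like the sketch below; the tensors are fabricated, and uncertainty_lambda defaults to 0 in the listing, which disables the penalty entirely.

import torch

bdash, V = 3, 11
logprobsf = torch.randn(bdash, V)    # per-beam next-word log-probs
epistemics = torch.rand(bdash)       # per-beam epistemic uncertainty estimate
uncertainty_lambda = 0.3

# every word on an uncertain beam is penalized by the same amount, so
# high-uncertainty beams fall behind in the joint log-prob ranking
logprobsf = logprobsf - uncertainty_lambda * epistemics.unsqueeze(-1)
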