# The beam_search variants below are excerpted model methods; they assume the following
# module-level imports, plus a repo-local `utils` module providing penalty_builder (and,
# for the batched variant, split_tensors).
import functools
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F


def beam_search(self, init_state, init_logprobs, *args, **kwargs):

    # function computes the similarity score to be augmented
    def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
        local_time = t - divm
        unaug_logprobsf = logprobsf.clone()
        for prev_choice in range(divm):
            prev_decisions = beam_seq_table[prev_choice][local_time]
            for sub_beam in range(bdash):
                for prev_labels in range(bdash):
                    logprobsf[sub_beam][prev_decisions[prev_labels]] = \
                        logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
        return unaug_logprobsf

    # does one step of classical beam search
    def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                  beam_seq_logprobs, beam_logprobs_sum, state):
        # INPUTS:
        # logprobsf: probabilities augmented after diversity
        # beam_size: obvious
        # t        : time instant
        # beam_seq : tensor containing the beams
        # beam_seq_logprobs: tensor containing the beam logprobs
        # beam_logprobs_sum: tensor containing joint logprobs
        # OUTPUTS:
        # beam_seq : tensor containing the word indices of the decoded captions
        # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
        # beam_logprobs_sum : joint log-probability of each beam
        ys, ix = torch.sort(logprobsf, 1, True)
        candidates = []
        cols = min(beam_size, ys.size(1))
        rows = beam_size
        if t == 0:
            rows = 1
        for c in range(cols):  # for each column (word, essentially)
            for q in range(rows):  # for each beam expansion
                # compute logprob of expanding beam q with word in (sorted) position c
                local_logprob = ys[q, c].item()
                candidate_logprob = beam_logprobs_sum[q] + local_logprob
                local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                candidates.append({'c': ix[q, c], 'q': q,
                                   'p': candidate_logprob, 'r': local_unaug_logprob})
        candidates = sorted(candidates, key=lambda x: -x['p'])

        new_state = [_.clone() for _ in state]
        # beam_seq_prev, beam_seq_logprobs_prev
        if t >= 1:
            # we'll need these as reference when we fork beams around
            beam_seq_prev = beam_seq[:t].clone()
            beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
        for vix in range(beam_size):
            v = candidates[vix]
            # fork beam index q into index vix
            if t >= 1:
                beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
            # rearrange recurrent states
            for state_ix in range(len(new_state)):
                # copy over state in previous beam q to new beam at vix
                new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
            # append new end terminal at the end of this beam
            beam_seq[t, vix] = v['c']  # c'th word is the continuation
            beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
            beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
        state = new_state
        return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

    # Start diverse_beam_search
    opt = kwargs['opt']
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    decoding_constraint = opt.get('decoding_constraint', 0)
    max_ppl = opt.get('max_ppl', 0)
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    bdash = beam_size // group_size  # beam per group

    # INITIALIZATIONS
    beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
    # logprobs: logprobs predicted in last time step, shape (beam_size, vocab_size+1)
    done_beams_table = [[] for _ in range(group_size)]
    state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
    logprobs_table = list(init_logprobs.chunk(group_size, 0))
    # END INIT

    # Chunk elements in the args
    args = list(args)
    args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
    args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

    for t in range(self.seq_length + group_size - 1):
        for divm in range(group_size):
            if t >= divm and t <= self.seq_length + divm - 1:
                # add diversity
                logprobsf = logprobs_table[divm].data.float()
                # suppress previous word
                if decoding_constraint and t - divm > 0:
                    logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(),
                                       float('-inf'))
                # suppress UNK tokens in the decoding
                logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                # diversity is added here
                # the function directly modifies the logprobsf values and hence, we need to return
                # the unaugmented ones for sorting the candidates in the end. # for historical reasons :-)
                unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm,
                                                diversity_lambda, bdash)

                # infer new beams
                beam_seq_table[divm],\
                beam_seq_logprobs_table[divm],\
                beam_logprobs_sum_table[divm],\
                state_table[divm],\
                candidates_divm = beam_step(logprobsf, unaug_logprobsf, bdash, t - divm,
                                            beam_seq_table[divm],
                                            beam_seq_logprobs_table[divm],
                                            beam_logprobs_sum_table[divm],
                                            state_table[divm])

                # if time's up... or if end token is reached then copy beams
                for vix in range(bdash):
                    if beam_seq_table[divm][t - divm, vix] == 0 or t == self.seq_length + divm - 1:
                        final_beam = {
                            'seq': beam_seq_table[divm][:, vix].clone(),
                            'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                            'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                            'p': beam_logprobs_sum_table[divm][vix].item()
                        }
                        final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                        # if max_ppl:
                        #     final_beam['p'] = final_beam['p'] / (t - divm + 1)
                        done_beams_table[divm].append(final_beam)
                        # don't continue beams from finished sequences
                        beam_logprobs_sum_table[divm][vix] = -1000

                # move the current group one step forward in time
                it = beam_seq_table[divm][t - divm]
                logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                    it.cuda(), *(args[divm] + [state_table[divm]]))

    # all beams are sorted by their log-probabilities
    done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
                        for i in range(group_size)]
    done_beams = reduce(lambda a, b: a + b, done_beams_table)
    return done_beams
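
# --- Illustrative sketch, not part of the original model code. It shows the Hamming
# diversity penalty that add_diversity applies above: every word already emitted by an
# earlier group at the same local time step has diversity_lambda subtracted from its
# log-prob, once per occurrence. Shapes and values here are toy assumptions.
def _demo_hamming_diversity_penalty():
    import torch

    bdash, vocab = 3, 6
    diversity_lambda = 0.5
    # words chosen by one earlier group at this time step, one per sub-beam
    prev_decisions = torch.tensor([2, 2, 5])
    logprobsf = torch.full((bdash, vocab), -1.0)

    penalized = logprobsf.clone()
    for sub_beam in range(bdash):
        for prev in prev_decisions.tolist():
            penalized[sub_beam, prev] -= diversity_lambda
    # word 2 was picked twice, so it is penalized twice; word 5 once
    print(penalized[0])  # tensor([-1.0, -1.0, -2.0, -1.0, -1.0, -1.5])
    return penalized
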
def beam_search(self, init_state, init_logprobs, *args, **kwargs):

    # function computes the similarity score to be augmented
    def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
        local_time = t - divm
        unaug_logprobs = logprobs.clone()
        batch_size = beam_seq_table[0].shape[0]

        if divm > 0:
            change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                for prev_labels in range(bdash):
                    change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
                                        change.new_ones(batch_size, 1))

            if local_time == 0:
                logprobs = logprobs - change * diversity_lambda
            else:
                logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda

        return logprobs, unaug_logprobs

    # does one step of classical beam search
    def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq,
                  beam_seq_logprobs, beam_logprobs_sum, state):
        # INPUTS:
        # logprobs: probabilities augmented after diversity, N*bxV
        # beam_size: obvious
        # t        : time instant
        # beam_seq : tensor containing the beams
        # beam_seq_logprobs: tensor containing the beam logprobs
        # beam_logprobs_sum: tensor containing joint logprobs
        # OUTPUTS:
        # beam_seq : tensor containing the word indices of the decoded captions, Nxbxl
        # beam_seq_logprobs : log-probability of each decision made, NxbxlxV
        # beam_logprobs_sum : joint log-probability of each beam, Nxb
        batch_size = beam_logprobs_sum.shape[0]
        vocab_size = logprobs.shape[-1]
        logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
        if t == 0:
            assert logprobs.shape[1] == 1
            beam_logprobs_sum = beam_logprobs_sum[:, :1]
        candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs  # beam_logprobs_sum Nxb, logprobs NxbxV
        ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
        ys, ix = ys[:, :beam_size], ix[:, :beam_size]
        beam_ix = ix // vocab_size  # Nxb, which beam
        selected_ix = ix % vocab_size  # Nxb, which word
        state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) *
                    logprobs.shape[1]).reshape(-1)  # N*b, which in Nxb beams

        if t > 0:
            # gather according to beam_ix
            assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
                    beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
            beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
            beam_seq_logprobs = beam_seq_logprobs.gather(
                1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))

        beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
        beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
            logprobs.reshape(batch_size, -1).gather(1, ix)
        assert (beam_logprobs_sum == ys).all()
        _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
        beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(
            1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size))  # NxbxV
        assert (_tmp_beam_logprobs == beam_logprobs).all()
        beam_seq_logprobs = torch.cat(
            [beam_seq_logprobs, beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

        new_state = [None for _ in state]
        for _ix in range(len(new_state)):
            # copy over state in previous beam q to new beam at vix
            new_state[_ix] = state[_ix][:, state_ix]
        state = new_state
        return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state

    # Start diverse_beam_search
    opt = kwargs['opt']
    temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    decoding_constraint = opt.get('decoding_constraint', 0)
    remove_bad_endings = opt.get('remove_bad_endings', 0)
    suppress_UNK = opt.get('suppress_UNK', 0)
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    bdash = beam_size // group_size  # beam per group

    batch_size = init_logprobs.shape[0]
    device = init_logprobs.device
    # INITIALIZATIONS
    beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
    beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device)
                               for _ in range(group_size)]
    beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

    # logprobs: logprobs predicted in last time step, shape (beam_size, vocab_size+1)
    done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
    # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
    # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
    state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
    # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
    logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
    # END INIT

    # Chunk elements in the args
    args = list(args)
    args = utils.split_tensors(group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
    if self.__class__.__name__ == 'AttEnsemble':
        args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))]
                for k in range(group_size)]  # group_name, arg_name, model_name
    else:
        args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

    for t in range(self.seq_length + group_size - 1):
        for divm in range(group_size):
            if t >= divm and t <= self.seq_length + divm - 1:
                # add diversity
                logprobs = logprobs_table[divm]
                # suppress previous word
                if decoding_constraint and t - divm > 0:
                    logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
                                      float('-inf'))
                if remove_bad_endings and t - divm > 0:
                    logprobs[torch.from_numpy(
                        np.isin(beam_seq_table[divm][:, :, t - divm - 1].cpu().numpy(),
                                self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
                # suppress UNK tokens in the decoding
                if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK':
                    logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000
                # diversity is added here
                # the function directly modifies the logprobs values and hence, we need to return
                # the unaugmented ones for sorting the candidates in the end. # for historical reasons :-)
                logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm,
                                                         diversity_lambda, bdash)

                # infer new beams
                beam_seq_table[divm],\
                beam_seq_logprobs_table[divm],\
                beam_logprobs_sum_table[divm],\
                state_table[divm] = beam_step(logprobs, unaug_logprobs, bdash, t - divm,
                                              beam_seq_table[divm],
                                              beam_seq_logprobs_table[divm],
                                              beam_logprobs_sum_table[divm],
                                              state_table[divm])

                # if time's up... or if end token is reached then copy beams
                for b in range(batch_size):
                    is_end = beam_seq_table[divm][b, :, t - divm] == 0
                    assert beam_seq_table[divm].shape[-1] == t - divm + 1
                    if t == self.seq_length + divm - 1:
                        is_end.fill_(1)
                    for vix in range(bdash):
                        if is_end[vix]:
                            final_beam = {
                                'seq': beam_seq_table[divm][b, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][b, vix].item()
                            }
                            final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                            done_beams_table[b][divm].append(final_beam)
                    beam_logprobs_sum_table[divm][b, is_end] -= 1000

                # move the current group one step forward in time
                it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                    it.cuda(), *(args[divm] + [state_table[divm]]))
                logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

    # all beams are sorted by their log-probabilities
    done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash]
                         for i in range(group_size)] for b in range(batch_size)]
    done_beams = [sum(_, []) for _ in done_beams_table]
    return done_beams
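
# --- Illustrative sketch, not part of the original model code. The batched beam_step
# above flattens the (beam, word) score matrix per image, takes the top-k over the
# flattened axis, and recovers the source beam and word with integer division / modulo.
# A toy example with assumed sizes:
def _demo_flat_topk_decomposition():
    import torch

    N, b, V = 1, 2, 5  # one image, two beams, five words
    candidate_logprobs = torch.tensor([[[-1.0, -4.0, -0.5, -3.0, -2.0],
                                        [-2.0, -0.2, -5.0, -1.5, -3.5]]])
    ys, ix = torch.sort(candidate_logprobs.reshape(N, -1), -1, True)
    ys, ix = ys[:, :b], ix[:, :b]  # keep the best `b` continuations
    beam_ix = ix // V              # which beam each winner came from
    selected_ix = ix % V           # which word extends that beam
    print(beam_ix)       # tensor([[1, 0]]) -> best score -0.2 comes from beam 1
    print(selected_ix)   # tensor([[1, 2]]) -> word 1 of beam 1, word 2 of beam 0
    return beam_ix, selected_ix
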
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
    opt = kwargs['opt']
    beam_size = opt.get('beam_size', 10)
    max_seqtree_length = opt.get('max_seqtree_length', 40)
    temperature = opt.get('temperature', 1)
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    suppress_EOB_factor = opt.get('suppress_EOB_factor', 1)
    # assert suppress_EOB_factor > 1

    batch_size = init_logprobs.size(0)
    device = init_logprobs.device

    beam_seq_table = torch.LongTensor(batch_size, beam_size, 0).to(device)
    beam_parent_idx_table = torch.LongTensor(batch_size, beam_size, max_seqtree_length).to(device)
    beam_parent_idx_table.fill_(0)
    beam_hidden_states_table = torch.FloatTensor(batch_size * beam_size, max_seqtree_length,
                                                 self.rnn_size).to(device)
    beam_cell_states_table = torch.FloatTensor(batch_size * beam_size, max_seqtree_length,
                                               self.rnn_size).to(device)
    # init state
    # init_state `(batch_size, rnn_size)` -> `(batch_size*beam_size, rnn_size)`
    beam_hidden_states_table[:, 0, :] = init_state[0].unsqueeze(dim=1).repeat(
        1, beam_size, 1).view(batch_size * beam_size, -1)
    beam_cell_states_table[:, 0, :] = init_state[1].unsqueeze(dim=1).repeat(
        1, beam_size, 1).view(batch_size * beam_size, -1)
    beam_seq_logprobs_table = torch.FloatTensor(batch_size, beam_size, 0, self.vocab_size + 1).to(device)
    beam_logprobs_sum_table = torch.zeros(batch_size, beam_size).to(device)

    logprobs = init_logprobs

    # generation-finished bookkeeping
    counter_table = torch.LongTensor(batch_size, beam_size).to(device)
    counter_table.fill_(1)
    seqLen_table = torch.LongTensor(batch_size, beam_size).to(device)
    seqLen_table.fill_(0)
    all_finished_table = torch.BoolTensor(batch_size, beam_size).to(device)
    all_finished_table.fill_(0)

    done_beams_table = [[] for _ in range(batch_size)]

    for i in range(1, max_seqtree_length):
        if suppress_EOB_factor > 1:
            logprobs[:, self.vocab_size] = logprobs[:, self.vocab_size] * suppress_EOB_factor
        logprobs[:, 0] = logprobs[:, 0] - 1000

        beam_seq_table, \
        beam_parent_idx_table, \
        beam_seq_logprobs_table, \
        beam_logprobs_sum_table, \
        (beam_hidden_states_table, beam_cell_states_table), \
        counter_table, \
        seqLen_table, \
        all_finished_table = self.beam_step(logprobs, beam_size, i - 1,
                                            beam_seq_table,
                                            beam_parent_idx_table,
                                            beam_seq_logprobs_table,
                                            beam_logprobs_sum_table,
                                            (beam_hidden_states_table, beam_cell_states_table),
                                            counter_table,
                                            seqLen_table,
                                            all_finished_table)

        for b in range(batch_size):
            is_end = all_finished_table[b, :]
            if i == max_seqtree_length - 1:
                is_end.fill_(1)
            for vix in range(beam_size):
                if is_end[vix]:
                    final_beam = {
                        'seq': beam_seq_table[b, vix].clone(),
                        'seq_idx': beam_parent_idx_table[b, vix].clone(),
                        'seqLen': seqLen_table[b, vix].clone(),
                        'logps': beam_seq_logprobs_table[b, vix].clone(),
                        'unaug_p': beam_seq_logprobs_table[b, vix].sum().item(),
                        'p': beam_logprobs_sum_table[b, vix].item(),
                        'counter': counter_table[b, vix].item()
                    }
                    final_beam['p'] = length_penalty(
                        (final_beam['seq'] != self.vocab_size).sum().item(), final_beam['p'])
                    # print(final_beam['seq'].size(), final_beam['seqLen'])
                    done_beams_table[b].append(final_beam)
            beam_logprobs_sum_table[b, is_end] -= 1000

        # move the current group one step forward in time
        seqtree = beam_seq_table.view(batch_size * beam_size, -1)
        parent_idx = beam_parent_idx_table.view(batch_size * beam_size, max_seqtree_length)

        p_it = torch.gather(seqtree, dim=1, index=parent_idx[:, i].clone().unsqueeze(1))
        p_it = p_it.squeeze(dim=1)
        p_xt = self.embed(p_it)

        hidden_states = beam_hidden_states_table.view(batch_size * beam_size, max_seqtree_length,
                                                      self.rnn_size)
        cell_states = beam_cell_states_table.view(batch_size * beam_size, max_seqtree_length,
                                                  self.rnn_size)
        p_idx = parent_idx[:, i].clone()
        p_idx = p_idx.unsqueeze(1).unsqueeze(1).expand(batch_size * beam_size, 1, self.hidden_size)
        p_hidden_state = torch.gather(hidden_states, dim=1, index=p_idx).squeeze(dim=1)
        p_cell_state = torch.gather(cell_states, dim=1, index=p_idx).squeeze(dim=1)
        p_state = p_hidden_state, p_cell_state

        if i % 3 == 1:
            s_xt = self.init_input(batch_size * beam_size)
            s_state = self.init_hidden(batch_size * beam_size)
        else:
            s_it = seqtree[:, i - 1].clone()
            s_xt = self.embed(s_it)
            s_hidden_state = hidden_states[:, i - 1].clone()
            s_cell_state = cell_states[:, i - 1].clone()
            s_state = s_hidden_state, s_cell_state

        logprobs, _state = self.get_logprobs_state(p_xt, s_xt, p_state, s_state, *args)
        # logprobs = logprobs.view(batch_size, beam_size, self.vocab_size+1)
        logprobs = F.log_softmax(logprobs, dim=-1)

        # beam_hidden_states_table[:,:,i,:] = state[0].view(-1, self.rnn_size)
        # beam_cell_states_table[:,:,i,:] = state[1].view(-1, self.rnn_size)
        beam_hidden_states_table[:, i, :] = _state[0]
        beam_cell_states_table[:, i, :] = _state[1]

    # all beams are sorted by their log-probabilities
    done_beams_table = [sorted(done_beams_table[b], key=lambda x: -x['p']) for b in range(batch_size)]
    # done_beams_table = [sorted(done_beams_table[b], key=lambda x: -x['p'])[:beam_size] for b in range(batch_size)]
    # done_beams = [sum(_, []) for _ in done_beams_table]
    return done_beams_table
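
# --- Illustrative sketch, not part of the original model code. The tree decoder above
# looks up each node's parent token and parent LSTM state with torch.gather over a
# per-node parent index. A toy example of the state lookup, with assumed sizes:
def _demo_parent_state_gather():
    import torch

    n_nodes, max_len, rnn_size = 2, 4, 3
    hidden_states = torch.arange(n_nodes * max_len * rnn_size, dtype=torch.float).reshape(
        n_nodes, max_len, rnn_size)
    parent_idx = torch.tensor([1, 0])  # node 0's parent sits at position 1, node 1's at 0
    idx = parent_idx.view(-1, 1, 1).expand(n_nodes, 1, rnn_size)
    parent_hidden = torch.gather(hidden_states, dim=1, index=idx).squeeze(dim=1)
    print(parent_hidden)  # row 0 == hidden_states[0, 1], row 1 == hidden_states[1, 0]
    return parent_hidden
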
def beam_search(self, init_state, init_logprobs, *args, **kwargs):

    # function computes the similarity score to be augmented
    def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
        local_time = t - divm
        unaug_logprobsf = logprobsf.clone()
        for prev_choice in range(divm):
            prev_decisions = beam_seq_table[prev_choice][local_time]
            for sub_beam in range(bdash):
                for prev_labels in range(bdash):
                    logprobsf[sub_beam][prev_decisions[prev_labels]] = \
                        logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
        return unaug_logprobsf

    # does one step of classical beam search
    def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                  beam_seq_logprobs, beam_logprobs_sum, state):
        # INPUTS:
        # logprobsf: probabilities augmented after diversity
        # beam_size: obvious
        # t        : time instant
        # beam_seq : tensor containing the beams
        # beam_seq_logprobs: tensor containing the beam logprobs
        # beam_logprobs_sum: tensor containing joint logprobs
        # OUTPUTS:
        # beam_seq : tensor containing the word indices of the decoded captions
        # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
        # beam_logprobs_sum : joint log-probability of each beam
        ys, ix = torch.sort(logprobsf, 1, True)
        candidates = []
        cols = min(beam_size, ys.size(1))
        rows = beam_size
        if t == 0:
            rows = 1
        for c in range(cols):  # for each column (word, essentially)
            for q in range(rows):  # for each beam expansion
                # compute logprob of expanding beam q with word in (sorted) position c
                local_logprob = ys[q, c].item()
                candidate_logprob = beam_logprobs_sum[q] + local_logprob
                local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                candidates.append({'c': ix[q, c], 'q': q,
                                   'p': candidate_logprob, 'r': local_unaug_logprob})
        candidates = sorted(candidates, key=lambda x: -x['p'])

        new_state = [_.clone() for _ in state]
        # beam_seq_prev, beam_seq_logprobs_prev
        if t >= 1:
            # we'll need these as reference when we fork beams around
            beam_seq_prev = beam_seq[:t].clone()
            beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
        for vix in range(beam_size):
            v = candidates[vix]
            # fork beam index q into index vix
            if t >= 1:
                beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
            # rearrange recurrent states
            for state_ix in range(len(new_state)):
                # copy over state in previous beam q to new beam at vix
                new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
            # append new end terminal at the end of this beam
            beam_seq[t, vix] = v['c']  # c'th word is the continuation
            beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
            beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
        state = new_state
        return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

    # Start diverse_beam_search
    opt = kwargs['opt']
    temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    decoding_constraint = opt.get('decoding_constraint', 1)
    remove_bad_endings = opt.get('remove_bad_endings', 1)
    block_trigrams = opt.get('block_trigrams', 1)
    opt['length_penalty'] = 'avg_0'
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    bdash = beam_size // group_size  # beam per group

    # INITIALIZATIONS
    beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
    # logprobs: logprobs predicted in last time step, shape (beam_size, vocab_size+1)
    done_beams_table = [[] for _ in range(group_size)]
    # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
    state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
    logprobs_table = list(init_logprobs.chunk(group_size, 0))
    # END INIT

    # Chunk elements in the args
    args = list(args)
    if self.__class__.__name__ == 'AttEnsemble':
        args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_]
                for args_ in args]  # arg_name, model_name, group_name
        args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))]
                for k in range(group_size)]  # group_name, arg_name, model_name
    else:
        args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
        args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

    trigrams = []
    for t in range(self.seq_length + group_size - 1):
        for divm in range(group_size):
            if t >= divm and t <= self.seq_length + divm - 1:
                # add diversity
                logprobsf = logprobs_table[divm].data.float()
                # suppress previous word
                if decoding_constraint and t - divm > 0:
                    logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), -10e20)
                    # if t - divm >= 2:
                    #     logprobsf.scatter_(1, beam_seq_table[divm][t-divm-2].unsqueeze(1).cuda(), -10e20)
                if remove_bad_endings and t - divm > 0:
                    logprobsf[torch.from_numpy(
                        np.isin(beam_seq_table[divm][t - divm - 1].cpu().numpy(),
                                self.bad_endings_ix).astype(bool)), 0] = -10e20
                # suppress UNK tokens in the decoding
                logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 10e20
                # diversity is added here
                # the function directly modifies the logprobsf values and hence, we need to return
                # the unaugmented ones for sorting the candidates in the end. # for historical reasons :-)
                if block_trigrams and t - divm >= 3:
                    # Store trigram generated at last step
                    prev_two_batch = beam_seq_table[divm][t - 3:t - 1]  # time*beam_size
                    for i in range(bdash):  # = seq.size(0)
                        prev_two = (prev_two_batch[0][i].item(), prev_two_batch[1][i].item())
                        current = beam_seq_table[divm][t - 1][i]
                        if t == 3:  # initialize
                            trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                        elif t > 3:
                            if prev_two in trigrams[i]:  # add to list
                                trigrams[i][prev_two].append(current)
                            else:  # create list
                                trigrams[i][prev_two] = [current]
                    # Block used trigrams at next step
                    prev_two_batch = beam_seq_table[divm][t - 2:t]
                    mask = torch.zeros(logprobsf.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                    for i in range(bdash):
                        prev_two = (prev_two_batch[0][i].item(), prev_two_batch[1][i].item())
                        if prev_two in trigrams[i]:
                            for j in trigrams[i][prev_two]:
                                mask[i, j] += 1
                    alpha = 10e20  # = 4
                    logprobsf = logprobsf + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

                unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm,
                                                diversity_lambda, bdash)

                # infer new beams
                beam_seq_table[divm],\
                beam_seq_logprobs_table[divm],\
                beam_logprobs_sum_table[divm],\
                state_table[divm],\
                candidates_divm = beam_step(logprobsf, unaug_logprobsf, bdash, t - divm,
                                            beam_seq_table[divm],
                                            beam_seq_logprobs_table[divm],
                                            beam_logprobs_sum_table[divm],
                                            state_table[divm])

                # if time's up... or if end token is reached then copy beams
                for vix in range(bdash):
                    if beam_seq_table[divm][t - divm, vix] == 0 or t == self.seq_length + divm - 1:
                        final_beam = {
                            'seq': beam_seq_table[divm][:, vix].clone(),
                            'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                            'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                            'p': beam_logprobs_sum_table[divm][vix].item()
                        }
                        final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                        done_beams_table[divm].append(final_beam)
                        # don't continue beams from finished sequences
                        beam_logprobs_sum_table[divm][vix] = -1000

                # move the current group one step forward in time
                it = beam_seq_table[divm][t - divm]
                logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                    it.cuda(), *(args[divm] + [state_table[divm]]))
                logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

    # all beams are sorted by their log-probabilities
    done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
                        for i in range(group_size)]
    done_beams = functools.reduce(lambda a, b: a + b, done_beams_table)
    return done_beams
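
# --- Illustrative sketch, not part of the original model code, of the trigram-blocking
# mask used above: per beam, any word that would repeat an already-generated trigram gets
# a large negative additive penalty. Toy sequences and sizes are assumptions.
def _demo_trigram_block_mask():
    import torch

    vocab = 6
    # trigrams already generated on this beam: (4, 2) was once followed by 3
    trigrams = {(4, 2): [3]}
    last_two = (4, 2)  # the two most recent words on the beam
    logprobsf = torch.zeros(1, vocab)
    mask = torch.zeros(1, vocab)
    for w in trigrams.get(last_two, []):
        mask[0, w] += 1
    alpha = 1e20
    blocked = logprobsf + mask * -0.693 * alpha  # effectively -inf on word 3 only
    print(blocked[0])
    return blocked
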
def beam_search(self, init_state, init_logprobs, init_aleatorics, init_epistemics, *args, **kwargs):

    # function computes the similarity score to be augmented
    def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
        local_time = t - divm
        unaug_logprobsf = logprobsf.clone()
        for prev_choice in range(divm):
            prev_decisions = beam_seq_table[prev_choice][local_time]
            for sub_beam in range(bdash):
                for prev_labels in range(bdash):
                    logprobsf[sub_beam][prev_decisions[prev_labels]] = \
                        logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
        return unaug_logprobsf

    # does one step of classical beam search
    def beam_step(logprobsf, unaug_logprobsf, aleatorics, epistemics, beam_size, t,
                  beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_seq_al, beam_seq_ep, state):
        # INPUTS:
        # logprobsf: probabilities augmented after diversity (beam_size, vocab_size)
        # aleatorics: aleatoric uncertainties evaluated at current step (beam_size,)
        # epistemics: epistemic uncertainties evaluated at current step (beam_size,)
        # beam_size: obvious
        # t        : time instant
        # beam_seq : tensor containing the beams
        # beam_seq_logprobs: tensor containing the beam logprobs
        # beam_logprobs_sum: tensor containing joint logprobs
        # beam_seq_al: tensor containing the aleatoric uncertainties of the candidates
        # beam_seq_ep: tensor containing the epistemic uncertainties of the candidates
        # OUTPUTS:
        # beam_seq : tensor containing the word indices of the decoded captions
        # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
        # beam_logprobs_sum : joint log-probability of each beam
        ys, ix = torch.sort(logprobsf, 1, True)
        candidates = []
        cols = min(beam_size, ys.size(1))
        rows = beam_size
        if t == 0:
            rows = 1
        for c in range(cols):  # for each column (word, essentially)
            for q in range(rows):  # for each beam expansion
                # compute logprob of expanding beam q with word in (sorted) position c
                local_logprob = ys[q, c].item()
                candidate_logprob = beam_logprobs_sum[q] + local_logprob
                # local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                candidates.append({'c': ix[q, c], 'q': q,
                                   'p': candidate_logprob, 'r': unaug_logprobsf[q]})
        candidates = sorted(candidates, key=lambda x: -x['p'])

        new_state = [_.clone() for _ in state]
        # beam_seq_prev, beam_seq_logprobs_prev
        if t >= 1:
            # we'll need these as reference when we fork beams around
            beam_seq_prev = beam_seq[:t].clone()
            beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            beam_seq_al_prev = beam_seq_al[:t].clone()
            beam_seq_ep_prev = beam_seq_ep[:t].clone()
        for vix in range(beam_size):
            v = candidates[vix]
            # fork beam index q into index vix
            if t >= 1:
                beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                beam_seq_al[:t, vix] = beam_seq_al_prev[:, v['q']]
                beam_seq_ep[:t, vix] = beam_seq_ep_prev[:, v['q']]
            # rearrange recurrent states
            for state_ix in range(len(new_state)):
                # copy over state in previous beam q to new beam at vix
                new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
            # append new end terminal at the end of this beam
            beam_seq[t, vix] = v['c']  # c'th word is the continuation
            beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
            beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
            beam_seq_al[t, vix] = aleatorics[v['q']]
            beam_seq_ep[t, vix] = epistemics[v['q']]
        state = new_state
        return beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_seq_al, beam_seq_ep, state, candidates

    # Start diverse_beam_search
    opt = kwargs['opt']
    temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    uncertainty_lambda = opt.get('uncertainty_lambda', 0)
    decoding_constraint = opt.get('decoding_constraint', 0)
    remove_bad_endings = opt.get('remove_bad_endings', 0)
    suppress_UNK = opt.get('suppress_UNK', 0)
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    bdash = beam_size // group_size  # beam per group

    # INITIALIZATIONS
    beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_()
                               for _ in range(group_size)]
    beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
    beam_seq_aleatorics_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    beam_seq_epistemics_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
    done_beams_table = [[] for _ in range(group_size)]
    # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
    state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
    # logprobs: logprobs predicted in last time step, shape (beam_size, vocab_size+1)
    logprobs_table = list(init_logprobs.chunk(group_size, 0))  # [(beam_size,)]
    aleatorics_table = list(init_aleatorics.chunk(group_size, 0))
    epistemics_table = list(init_epistemics.chunk(group_size, 0))
    # END INIT

    # Chunk elements in the args
    args = list(args)
    if self.__class__.__name__ == 'AttEnsemble':
        args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_]
                for args_ in args]  # arg_name, model_name, group_name
        args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))]
                for k in range(group_size)]  # group_name, arg_name, model_name
    else:
        args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
        args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

    for t in range(self.seq_length + group_size - 1):
        for divm in range(group_size):
            if t >= divm and t <= self.seq_length + divm - 1:
                # add diversity
                logprobsf = logprobs_table[divm].float()
                # suppress previous word
                if decoding_constraint and t - divm > 0:
                    logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(),
                                       float('-inf'))
                if remove_bad_endings and t - divm > 0:
                    logprobsf[torch.from_numpy(
                        np.isin(beam_seq_table[divm][t - divm - 1].cpu().numpy(),
                                self.bad_endings_ix)), 0] = float('-inf')
                # suppress UNK tokens in the decoding
                if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK':
                    logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                # diversity is added here
                # the function directly modifies the logprobsf values and hence, we need to return
                # the unaugmented ones for sorting the candidates in the end. # for historical reasons :-)
                unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm,
                                                diversity_lambda, bdash)

                # get current uncertainties
                aleatorics = aleatorics_table[divm]
                epistemics = epistemics_table[divm]
                # add uncertainty
                logprobsf = logprobsf - uncertainty_lambda * epistemics.unsqueeze(-1)

                # infer new beams
                beam_seq_table[divm],\
                beam_seq_logprobs_table[divm],\
                beam_logprobs_sum_table[divm],\
                beam_seq_aleatorics_table[divm],\
                beam_seq_epistemics_table[divm],\
                state_table[divm],\
                candidates_divm = beam_step(logprobsf, unaug_logprobsf, aleatorics, epistemics,
                                            bdash, t - divm,
                                            beam_seq_table[divm],
                                            beam_seq_logprobs_table[divm],
                                            beam_logprobs_sum_table[divm],
                                            beam_seq_aleatorics_table[divm],
                                            beam_seq_epistemics_table[divm],
                                            state_table[divm])

                # if time's up... or if end token is reached then copy beams
                for vix in range(bdash):
                    if beam_seq_table[divm][t - divm, vix] == 0 or t == self.seq_length + divm - 1:
                        final_beam = {
                            'seq': beam_seq_table[divm][:, vix].clone(),
                            'aleatorics': beam_seq_aleatorics_table[divm][:, vix].clone(),
                            'epistemics': beam_seq_epistemics_table[divm][:, vix].clone(),
                            'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                            'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                            'p': beam_logprobs_sum_table[divm][vix].item()
                        }
                        final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                        done_beams_table[divm].append(final_beam)
                        # don't continue beams from finished sequences
                        beam_logprobs_sum_table[divm][vix] = -1000

                # move the current group one step forward in time
                it = beam_seq_table[divm][t - divm]
                logprobs_table[divm], state_table[divm], aleatorics_table[divm], epistemics_table[divm] = \
                    self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
                logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

    # all beams are sorted by their log-probabilities
    done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
                        for i in range(group_size)]
    done_beams = sum(done_beams_table, [])
    return done_beams
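
# --- Illustrative sketch, not part of the original model code, of the epistemic
# uncertainty penalty applied above: each beam's per-step epistemic uncertainty is
# broadcast over the vocabulary and subtracted from that beam's word log-probs, so
# uncertain beams score lower. Sizes and values here are assumptions.
def _demo_uncertainty_penalty():
    import torch

    bdash, vocab = 2, 4
    uncertainty_lambda = 0.3
    logprobsf = torch.full((bdash, vocab), -1.0)
    epistemics = torch.tensor([0.1, 2.0])  # beam 1 is much more uncertain
    penalized = logprobsf - uncertainty_lambda * epistemics.unsqueeze(-1)
    print(penalized)  # beam 0 -> -1.03 everywhere, beam 1 -> -1.60 everywhere
    return penalized
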