def listener(self, state, utterance, depth):
    """Return the log posterior over worlds (image, rationality, speaker) given an utterance.

    Inverts the speaker model: for every cell of the world prior, ask the
    speaker how probable the observed utterance is in that world, temper the
    scores by the listener's rationality, combine with the prior from the
    previous timestep, and log-normalize.

    Args:
        state: RSA_State carrying world_priors, timestep, dim and
            listener_rationality.
        utterance: index of the observed utterance segment.
        depth: recursion depth at which the speaker scores the utterance.

    Returns:
        Log-normalized posterior array with the same shape as the world prior.
    """
    # scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3;
    # the function lives in scipy.special (as the sibling listener uses).
    from scipy.special import logsumexp

    # prior over worlds carried over from the previous timestep
    world_prior = state.world_priors[state.timestep - 1]
    scores = np.zeros((world_prior.shape))
    # enumerate every index of the (image x rationality x speaker) grid
    for n_tuple in itertools.product(
            *[list(range(x)) for x in world_prior.shape]):
        world = RSA_World(target=n_tuple[state.dim["image"]],
                          rationality=n_tuple[state.dim["rationality"]],
                          speaker=n_tuple[state.dim["speaker"]])
        # NOTE: the speaker is called at `depth`, NOT depth-1
        out = self.speaker(state=state, world=world, depth=depth)
        scores[n_tuple] = out[utterance]
    # temper the likelihood in log space by the listener's rationality
    scores = scores * state.listener_rationality
    world_posterior = (scores + world_prior) - logsumexp(scores + world_prior)
    return world_posterior
def listener_simple(self, state, utterance, depth):
    """Return the log posterior over the two candidate targets given an utterance.

    Simplified listener for the fixed (2, 1, 1) world shape: only the target
    index varies; rationality and speaker indices are pinned to 0. The base
    case listener is either neurally trained, or inferred from neural s0,
    given the state's current prior on images.

    Args:
        state: RSA_State carrying world_priors, timestep and
            listener_rationality.
        utterance: index of the observed utterance segment.
        depth: recursion depth at which the speaker scores the utterance.

    Returns:
        Log-normalized (2, 1, 1) posterior array.
    """
    # scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3;
    # the function lives in scipy.special.
    from scipy.special import logsumexp

    world_prior = state.world_priors[state.timestep - 1]
    assert world_prior.shape == (2, 1, 1)
    scores = np.zeros((2, 1, 1))
    for i in range(2):
        world = RSA_World(target=i, rationality=0, speaker=0)
        # NOTE: the speaker is called at `depth`, NOT depth-1
        out = self.speaker(state=state, world=world, depth=depth)
        scores[i] = out[utterance]
    # temper the likelihood in log space by the listener's rationality
    scores = scores * state.listener_rationality
    world_posterior = (scores + world_prior) - logsumexp(scores + world_prior)
    return world_posterior
def listener(self, state, utterance, depth):
    """Compute a log-space distribution over referents given an utterance.

    Method: for each candidate world cell, query the speaker model for the
    probability of the observed utterance, collect those values, temper them
    by the listener's rationality, combine with the prior carried over from
    the previous timestep, and log-normalize the result.
    """
    # prior over worlds from the previous timestep
    prior = state.world_priors[state.timestep - 1]
    likelihoods = np.zeros(prior.shape)
    # walk every index of the (image x rationality x speaker) grid
    for idx in np.ndindex(*prior.shape):
        candidate = RSA_World(target=idx[state.dim["image"]],
                              rationality=idx[state.dim["rationality"]],
                              speaker=idx[state.dim["speaker"]])
        # the speaker is consulted at the same depth (NOT depth-1)
        token_dist = self.speaker(state=state, world=candidate, depth=depth)
        # keep only the probability assigned to the observed utterance
        likelihoods[idx] = token_dist[utterance]
    # apply listener rationality (1 by default) in log space
    tempered = likelihoods * state.listener_rationality
    joint = tempered + prior
    return joint - scipy.special.logsumexp(joint)
def listener_simple(self, state, utterance, depth):
    """Log posterior over two candidate targets, with speaker and rationality fixed at 0.

    Base-case listener: either neurally trained directly, or inferred from the
    neural s0 under the state's current prior on images.
    """
    prior = state.world_priors[state.timestep - 1]
    assert prior.shape == (2, 1, 1)
    print("world prior", np.exp(prior))
    likelihoods = np.zeros((2, 1, 1))
    for target_idx in range(2):
        candidate = RSA_World(target=target_idx, rationality=0, speaker=0)
        # the speaker is consulted at the same depth (NOT depth-1)
        token_dist = self.speaker(state=state, world=candidate, depth=depth)
        likelihoods[target_idx] = token_dist[utterance]
    # temper by listener rationality, add the prior, then log-normalize
    joint = likelihoods * state.listener_rationality + prior
    return joint - scipy.special.logsumexp(joint)
def ana_greedy(rsa, initial_world_prior, speaker_rationality, speaker, target,
               pass_prior=True, listener_rationality=1.0, depth=0,
               start_from=[]):
    """Greedily unroll an RSA speaker into a single caption.

    speaker_rationality, listener_rationality: see speaker and listener code
        for what they do: control strength of conditioning.
    depth: the number of levels of RSA: depth 0 uses the literal speaker to
        unroll, depth n uses speaker n to unroll, and listener n to update at
        each step.
    start_from: a partial caption you start the unrolling from.
    initial_world_prior: a prior on the world to start with.

    Returns:
        A one-element list [(caption_string, summed_log_prob)].
    """
    # this RSA passes along a state: see rsa_state
    state = RSA_State(initial_world_prior,
                      listener_rationality=listener_rationality)
    # '^' is the start-of-caption marker used by this version
    state.context_sentence = ['^'] + start_from
    world = RSA_World(target=target,
                      rationality=speaker_rationality,
                      speaker=speaker)
    probs = []
    for timestep in tqdm(range(len(start_from) + 1, max_sentence_length)):
        state.timestep = timestep
        # distribution over next segments from the (possibly pragmatic) speaker
        s = rsa.speaker(state=state, world=world, depth=depth)
        segment = np.argmax(s)
        probs.append(np.max(s))
        if pass_prior:
            # update the world prior for the next timestep via the listener
            l = rsa.listener(state=state, utterance=segment, depth=depth)
            state.world_priors[state.timestep] = l
        state.context_sentence += [rsa.idx2seg[segment]]
        # stop once the stop token for this segmentation type is produced
        if rsa.idx2seg[segment] == stop_token[rsa.seg_type]:
            break
    summed_probs = np.sum(np.asarray(probs))
    return [("".join(state.context_sentence), summed_probs)]
def ana_beam(
        rsa,
        initial_world_prior,
        speaker_rationality,
        target,
        speaker,
        pass_prior=True,
        listener_rationality=1.0,
        depth=0,
        start_from=[],
        beam_width=len(sym_set),
        cut_rate=1,
        decay_rate=0.0,
        beam_decay=0,
):
    """Beam-search unrolling of the RSA speaker.

    speaker_rationality, listener_rationality: see speaker and listener code
        for what they do: control strength of conditioning
    depth: the number of levels of RSA: depth 0 uses the literal speaker to
        unroll, depth n uses speaker n to unroll, and listener n to update at
        each step
    start_from: a partial caption you start the unrolling from
    initial_world_prior: a prior on the world to start with
    beam_width: width beam is cut down to every cut_rate iterations of the
        unrolling
    cut_rate: how often beam is cut down to beam_width
    beam_decay: amount by which beam_width is lessened after each iteration
    decay_rate: a multiplier that makes later decisions in the unrolling
        matter less: 0.0 does no decay. negative decay makes start matter more

    Returns a list of (caption_string, cumulative_log_prob) pairs, best first.
    """
    state = RSA_State(initial_world_prior,
                      listener_rationality=listener_rationality)
    context_sentence = start_from
    state.context_sentence = context_sentence
    world = RSA_World(target=target,
                      rationality=speaker_rationality,
                      speaker=speaker)
    # NOTE(review): the two assignments below repeat the ones above; kept as-is
    context_sentence = start_from
    state.context_sentence = context_sentence
    # beam entries are (token list, world priors, cumulative log prob)
    sent_worldprior_prob = [(state.context_sentence, state.world_priors, 0.0)]
    final_sentences = []
    toc = time.time()  # timer start; never read again in this view
    for timestep in tqdm(range(len(start_from) + 1, max_sentence_length)):
        state.timestep = timestep
        new_sent_worldprior_prob = []
        # expand every surviving beam entry by one segment
        for sent, worldpriors, old_prob in sent_worldprior_prob:
            state.world_priors = worldpriors
            if state.timestep > 1:
                # condition on everything except the newest segment ...
                state.context_sentence = sent[:-1]
                seg = sent[-1]
                if depth > 0 and pass_prior:
                    # ... and refresh this beam's world prior with the
                    # listener's posterior after hearing that segment
                    l = rsa.listener(state=state,
                                     utterance=rsa.seg2idx[seg],
                                     depth=depth)
                    state.world_priors[state.timestep - 1] = copy.deepcopy(l)
            state.context_sentence = sent
            # distribution over next segments for this beam entry
            s = rsa.speaker(state=state, world=world, depth=depth)
            for seg, prob in enumerate(np.squeeze(s)):
                new_sentence = copy.deepcopy(sent)
                new_sentence += [rsa.idx2seg[seg]]
                state.context_sentence = new_sentence
                # time-decayed segment score added to the running total
                new_prob = (
                    prob *
                    (1 / math.pow(state.timestep, decay_rate))) + old_prob
                new_sent_worldprior_prob.append(
                    (new_sentence, worldpriors, new_prob))
        rsa.flush_cache()
        # best-first ordering by cumulative score
        sent_worldprior_prob = sorted(new_sent_worldprior_prob,
                                      key=lambda x: x[-1],
                                      reverse=True)
        if state.timestep % cut_rate == 0:
            # cut down to size
            sent_worldprior_prob = sent_worldprior_prob[:beam_width]
            new_sent_worldprior_prob = []
            for sent, worldprior, prob in sent_worldprior_prob:
                if sent[-1] == stop_token[rsa.seg_type]:
                    # finished caption: move it off the active beam
                    final_sentence = copy.deepcopy(sent)
                    final_sentences.append((final_sentence, prob))
                else:
                    new_triple = copy.deepcopy((sent, worldprior, prob))
                    new_sent_worldprior_prob.append(new_triple)
            sent_worldprior_prob = new_sent_worldprior_prob
            if len(final_sentences) >= 50:
                # enough finished captions collected: return them best-first
                sentences = sorted(final_sentences,
                                   key=lambda x: x[-1],
                                   reverse=True)
                output = []
                for i, (sent, prob) in enumerate(sentences):
                    output.append(("".join(sent), prob))
                return output
            if beam_decay < beam_width:
                # shrink the beam for the next iteration
                beam_width -= beam_decay
    else:
        # NOTE(review): reconstructed as for/else from a collapsed source
        # line — the loop body has no break, so this runs whenever the
        # timestep budget is exhausted without the >=50 early return above;
        # confirm placement against upstream history.
        sentences = sorted(final_sentences,
                           key=lambda x: x[-1],
                           reverse=True)
        output = []
        for i, (sent, prob) in enumerate(sentences):
            output.append(("".join(sent), prob))
        return output
def ana_mixed_beam(rsa,
                   initial_world_prior,
                   speaker_rationality,
                   target,
                   speaker,
                   pass_prior=False,
                   listener_rationality=1.0,
                   start_from=[],
                   beam_width=5,
                   cut_rate=1,
                   decay_rate=0.0,
                   beam_decay=0,
                   no_progress_bar=False):
    """Beam search that mixes pragmatic and literal speakers per token.

    Unlike ana_beam there is no depth argument: depth is chosen per timestep
    inside the loop — depth 1 right after a space segment or at the first
    timestep, depth 0 otherwise.

    speaker_rationality, listener_rationality: see speaker and listener code
        for what they do: control strength of conditioning
    start_from: a partial caption you start the unrolling from
    initial_world_prior: a prior on the world to start with
    beam_width: number of beam entries kept (and segments expanded) per step
    decay_rate: a multiplier that makes later decisions in the unrolling
        matter less: 0.0 does no decay. negative decay makes start matter more

    NOTE(review): pass_prior, cut_rate and beam_decay are accepted but never
    read in this body; final_sentences is never appended to — confirm intent.

    Returns a list of (caption_string, cumulative_log_prob) pairs, best first.
    """
    print("Hello this is mixed beam search")
    # initialize state object
    # (containing initial - uniform - world priors for t0)
    # default: start from empty token list
    state = RSA_State(initial_world_prior,
                      listener_rationality=listener_rationality)
    context_sentence = start_from
    state.context_sentence = context_sentence
    # initialize world object
    world = RSA_World(target=target,
                      rationality=speaker_rationality,
                      speaker=speaker)
    # NOTE(review): repeated assignments kept from the original
    context_sentence = start_from
    state.context_sentence = context_sentence
    # beam entries are (token list, world priors, cumulative log prob);
    # initial probability value is 0.0 (log space)
    sent_worldprior_prob = [(state.context_sentence, state.world_priors, 0.0)]
    final_sentences = []
    toc = time.time()  # timer start; never read again in this view
    # iterate through individual time steps / tokens to be produced
    for timestep in tqdm(range(len(start_from) + 1, max_sentence_length),
                         disable=no_progress_bar):
        state.timestep = timestep
        new_sent_worldprior_prob = []
        # iterate through previous beams
        for sent, worldpriors, old_prob in sent_worldprior_prob:
            state.world_priors = worldpriors
            if state.timestep > 1:
                seg = sent[-1]
                # pragmatic speaker (depth 1) right after a space,
                # literal speaker (depth 0) mid-word
                if seg == ' ':
                    depth = 1
                else:
                    depth = 0
            else:
                # first step is always pragmatic
                depth = 1
            state.context_sentence = sent
            # get probability distribution over possible next tokens from speaker
            s = rsa.speaker(state=state, world=world, depth=depth)
            # only expand the beam_width most likely segments
            top_segs = np.argsort(s)[::-1][:beam_width]
            # iterate through possible next tokens
            for seg in top_segs:
                prob = s[seg]
                # add new segment to existing sentence
                new_sentence = copy.deepcopy(sent)
                new_sentence += [rsa.idx2seg[seg]]
                state.context_sentence = new_sentence
                # calculate new (time-decayed) probability for the sentence
                new_prob = (
                    prob *
                    (1 / math.pow(state.timestep, decay_rate))) + old_prob
                new_sent_worldprior_prob.append(
                    (new_sentence, worldpriors, new_prob))
        rsa.flush_cache()
        # sort possible next sentences by probability (descending)
        sent_worldprior_prob = sorted(new_sent_worldprior_prob,
                                      key=lambda x: x[-1],
                                      reverse=True)
        sent_worldprior_prob = sent_worldprior_prob[:beam_width]
        if len(sent_worldprior_prob) > 1:
            top_sent = sent_worldprior_prob[0][0]
            # stop as soon as the best beam entry ends with the end token
            if top_sent[-1] == '<end>':
                output = []
                print(top_sent)
                for sent, worldprior, prob in sent_worldprior_prob:
                    final_sentence = copy.deepcopy(sent)
                    output.append(("".join(final_sentence), prob))
                return output
    # max sentence length reached without the top beam ending
    output = []
    for sent, worldprior, prob in sent_worldprior_prob:
        final_sentence = copy.deepcopy(sent)
        output.append(("".join(final_sentence), prob))
    return output
def ana_greedy(rsa, initial_world_prior, speaker_rationality, speaker, target,
               pass_prior=False, listener_rationality=1.0, depth=0,
               start_from=[], start_token='<start>', end_token='<end>',
               no_progress_bar=False):
    """Greedily unroll the (possibly pragmatic) speaker into one caption.

    speaker_rationality, listener_rationality: see speaker and listener code
        for what they do: control strength of conditioning.
    depth: the number of levels of RSA: depth 0 uses the literal speaker to
        unroll, depth n uses speaker n to unroll, and listener n to update at
        each step.
    start_from: a partial caption you start the unrolling from.
    initial_world_prior: a prior on the world to start with.

    :Paper: To perform greedy unrolling [...] for either S0 or S1, we
    initialize the state as a partial caption pc0 consisting of only the
    start token, and a uniform prior over the images ip0. Then, for t > 0,
    we use our incremental speaker model S0 or S1 to generate a distribution
    over the subsequent character S t (u|w, ipt, pct), and add the character
    u with highest probability density to pct, giving us pct+1. We then run
    our listener model L1 on u, to obtain a distribution
    ipt+1 = Lt 1(w|u, ipt, pct) over images that the L0 can use at the next
    timestep.

    Returns:
        A one-element list [(caption_string, summed_log_prob)].
    """
    # this RSA passes along a state: see rsa_state
    # contains initial world priors (default: uniform priors for all tokens)
    state = RSA_State(initial_world_prior,
                      listener_rationality=listener_rationality)
    # initialize partial context caption with the start symbol
    state.context_sentence = [start_token] + start_from
    world = RSA_World(target=target,
                      rationality=speaker_rationality,
                      speaker=speaker)
    probs = []
    # perform greedy unrolling
    for timestep in tqdm(range(len(start_from) + 1, max_sentence_length),
                         disable=no_progress_bar):
        state.timestep = timestep
        # probability distribution over possible next tokens from the speaker
        s = rsa.speaker(state=state, world=world, depth=depth)
        # highest-probability next token and its probability
        segment = np.argmax(s)
        probs.append(np.max(s))
        if pass_prior:
            # update world_priors with rsa.listener so future timesteps use
            # an informed posterior instead of the initial prior
            l = rsa.listener(state=state, utterance=segment, depth=depth)
            state.world_priors[state.timestep] = l
        # add the chosen token to the context sentence
        state.context_sentence += [rsa.idx2seg[segment]]
        # break once the end token has been produced
        if rsa.idx2seg[segment] == end_token:
            break
    summed_probs = np.sum(np.asarray(probs))
    return [("".join(state.context_sentence), summed_probs)]