def start(self, start_with_nothing):
    """
    Create the root LM state and cache the model's scores for it

    The model is run once on a prefix that holds only the dictionary's EOS
    index. The log-probabilities taken at the last position are stored,
    together with that prefix, in ``self.states`` and the new state is
    queued in ``self.stateq``.

    Parameters:
    -----------
    start_with_nothing: not read by this implementation

    Returns:
    --------
    LMState: the newly created root state
    """
    root = LMState()
    eos_prefix = torch.LongTensor([[self.dictionary.eos()]])

    if self.save_incremental:
        incr = {}
    else:
        incr = None

    with torch.no_grad():
        output = self.model(eos_prefix.cuda(), incremental_state=incr)
        log_probs = self.model.get_normalized_probs(
            output, log_probs=True, sample=None
        )

    # Cached entries hold CPU data only; the incremental state (when used)
    # is moved back from the CUDA device before it is stored.
    if incr is not None:
        incr = apply_to_sample(lambda t: t.cpu(), incr)

    prefix_np = eos_prefix.numpy()
    last_step = log_probs[0, -1].cpu().numpy()
    self.states[root] = FairseqLMState(prefix_np, incr, last_step)
    self.stateq.append(root)

    return root
def score(self, state: LMState, token_index: int, no_cache: bool = False):
    """
    Evaluate language model based on the current lm state and new word

    Parameters:
    -----------
    state: current lm state
    token_index: index of the word
                 (can be lexicon index then you should store inside LM the
                  mapping between indices of lexicon and lm, or lm index of a word)
    no_cache: when True, neither the recomputed entry for ``state`` nor the
              entry for the child state is written to ``self.states``

    Returns:
    --------
    (LMState, float): pair of (new state, score for the current word);
                      the score is ``-inf`` when ``token_index`` equals
                      ``self.unk``
    """
    entry = self.states[state]

    if entry.probs is None:
        # No scores are stored for this state (never computed, or dropped by
        # the cache trimming below): run the model again on the saved prefix.
        incr = entry.incremental_state
        if incr is not None:
            incr = incr.copy()

        with torch.no_grad():
            if incr is not None:
                incr = apply_to_sample(lambda t: t.cuda(), incr)
            elif self.save_incremental:
                incr = {}

            output = self.model(
                torch.from_numpy(entry.prefix).cuda(),
                incremental_state=incr,
            )
            log_probs = self.model.get_normalized_probs(
                output, log_probs=True, sample=None
            )

            if incr is not None:
                incr = apply_to_sample(lambda t: t.cpu(), incr)

            entry = FairseqLMState(
                entry.prefix, incr, log_probs[0, -1].cpu().numpy()
            )

        if not no_cache:
            self.states[state] = entry
            self.stateq.append(state)

    token_score = entry.probs[token_index].item()

    # Trim the cache: the oldest queued states keep only their prefix, which
    # is what the recomputation branch above needs to restore their scores.
    limit = self.max_cache
    while len(self.stateq) > limit:
        oldest = self.stateq.popleft()
        stale = self.states[oldest]
        self.states[oldest] = FairseqLMState(stale.prefix, None, None)

    child = state.child(token_index)
    if not (child in self.states or no_cache):
        # Register the child with the extended prefix only; its scores are
        # left as None and computed when the child itself is scored.
        extended = np.concatenate(
            [entry.prefix, torch.LongTensor([[token_index]])], -1
        )
        self.states[child] = FairseqLMState(
            extended, entry.incremental_state, None
        )

    if token_index == self.unk:
        token_score = float("-inf")

    return child, token_score