def loglikelihood_rolling(self, requests):
    """Score each request's full string as a sum of rolling-window log-likelihoods.

    Args:
        requests: iterable of 1-tuples ``(string,)``. Each string is encoded
            with ``self.tok_encode``. It is then split by
            ``utils.get_rolling_token_windows`` using ``self.eot_token_id`` as
            the prefix token, ``self.max_length`` as the maximum sequence
            length, and a context length of 1.

    Returns:
        A list with one entry per request. Each entry is the sum of the first
        element (the log-likelihood) of every per-window result returned by
        ``self._loglikelihood_tokens``.
    """
    # TODO: Implement caching once we've confirmed the perplexity implementation
    # TODO: automatic batch size detection for vectorization
    totals = []
    for (text,) in tqdm(requests):
        windows = utils.get_rolling_token_windows(
            token_list=self.tok_encode(text),
            prefix_token=self.eot_token_id,
            max_seq_len=self.max_length,
            context_len=1,
        )
        disjoint = [utils.make_disjoint_window(window) for window in windows]
        # Prepend None to each window tuple. Presumably this fills the
        # cache-key slot _loglikelihood_tokens expects (TODO confirm).
        window_requests = [(None,) + pair for pair in disjoint]

        # TODO: extract out this call so it only gets called once and also
        # somehow figure out partial caching for that
        per_window = self._loglikelihood_tokens(window_requests, disable_tqdm=True)

        # Keep only each result's first element (the log-likelihood);
        # is_greedy is discarded.
        totals.append(sum(result[0] for result in per_window))
    return totals
def test_get_rolling_token_windows_v2():
    """Check rolling windows over 34 tokens with a long context (context_len=8).

    The first window predicts 10 tokens after the -100 prefix token. Every
    later window predicts 3 new tokens, conditioned on the preceding 10.
    """
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [10, 11, 12]),
        ([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [13, 14, 15]),
        ([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [16, 17, 18]),
        ([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [19, 20, 21]),
        ([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [22, 23, 24]),
        ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [25, 26, 27]),
        ([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [28, 29, 30]),
        ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [31, 32, 33]),
    ]
    tokens = list(range(34))
    windows = [
        (context, predictions)
        for context, predictions in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=8,
        )
    ]
    # The total number of predicted tokens must equal the input length.
    assert sum(len(predictions) for _, predictions in windows) == len(tokens)
    assert expected == windows
def test_get_rolling_token_windows_v5():
    """Check rolling windows with the minimal context (context_len=1) over 30 tokens.

    The input length is a multiple of max_seq_len, so every window predicts
    a full block of 10 tokens.
    """
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (
            [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
            [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        ),
        (
            [19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
            [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        ),
    ]
    tokens = list(range(30))
    windows = [
        (context, predictions)
        for context, predictions in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]
    # The total number of predicted tokens must equal the input length.
    assert sum(len(predictions) for _, predictions in windows) == len(tokens)
    assert expected == windows
def test_get_rolling_token_windows_empty():
    """Check that an empty token list produces no windows at all."""
    windows = get_rolling_token_windows(
        token_list=[],
        prefix_token=-100,
        max_seq_len=2,
        context_len=1,
    )
    window_count = sum(1 for _ in windows)
    assert window_count == 0
def test_get_rolling_token_windows_v6():
    """Check tiny windows (max_seq_len=2) over 9 tokens, ending in a short window.

    The final window predicts a single token (8) but still carries a
    full-length context.
    """
    expected = [
        ([-100, 0], [0, 1]),
        ([1, 2], [2, 3]),
        ([3, 4], [4, 5]),
        ([5, 6], [6, 7]),
        ([6, 7], [8]),
    ]
    tokens = list(range(9))
    windows = [
        (context, predictions)
        for context, predictions in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=2,
            context_len=1,
        )
    ]
    # The total number of predicted tokens must equal the input length.
    assert sum(len(predictions) for _, predictions in windows) == len(tokens)
    assert expected == windows
def loglikelihood_rolling(self, requests):
    """Score each request's full string as a sum of rolling-window token log-probs.

    Args:
        requests: iterable of 1-tuples ``(string,)``. Each string is encoded
            via ``self.tokenizer.encode_plus(...)["input_ids"]``. It is then
            split by ``utils.get_rolling_token_windows`` using
            ``self.end_of_text_token_id`` as the prefix token,
            ``self.MAX_LENGTH`` as the maximum sequence length, and a context
            length of 1.

    Returns:
        A list with one entry per request. Each entry is the sum of the
        concatenated ``"logprobs"`` returned by ``self.get_token_logprobs`` for
        every window. A string that encodes to no tokens produces no windows
        and scores 0.0.
    """
    # TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
    loglikelihoods = []
    for (text,) in tqdm(requests):
        encoded = self.tokenizer.encode_plus(text)["input_ids"]
        rolling_token_windows = utils.get_rolling_token_windows(
            token_list=encoded,
            prefix_token=self.end_of_text_token_id,
            max_seq_len=self.MAX_LENGTH,
            context_len=1,
        )
        window_logprobs = [
            self.get_token_logprobs(
                input_tokens=input_tokens,
                pred_tokens=pred_tokens,
            )["logprobs"]
            for input_tokens, pred_tokens in rolling_token_windows
        ]
        if not window_logprobs:
            # An empty token list yields no windows. np.concatenate([]) would
            # raise ValueError, so score the empty string as 0.0 instead.
            loglikelihoods.append(0.0)
            continue
        loglikelihoods.append(np.concatenate(window_logprobs).sum())
    return loglikelihoods