def greedy_until(self, requests, batch=False):
    self.model.module.inference_mode()
    res = []

    def _collate(x):
        toks = self.tokenizer.encode(x[0])
        return (len(toks), x[0])

    reord = utils.Reorderer(requests, _collate)

    for context, until in tqdm(reord.get_reordered()):
        if isinstance(until, str):
            until = [until]
        stop_tokens = [self.tokenizer.encode(i) for i in until]
        cont = self.generate(
            text=context,
            stop_tokens=stop_tokens,
            recompute=self.neox_args.recompute,
        )
        s = cont[0]["text"] or ""

        for term in until:
            s = s.split(term)[0]

        # partial caching
        self.cache_hook.add_partial("greedy_until", (context, until), s)

        res.append(s)

    self.model.module.train_mode()
    return reord.get_original(res)
def greedy_until(self, requests):
    # TODO: implement fully general `until` that handles untils that are
    # multiple tokens or that span multiple tokens correctly
    res = []

    def _collate(x):
        toks = self.tokenizer.encode(x[0])
        return (len(toks), x[0])

    reord = utils.Reorderer(requests, _collate)

    for context, until in tqdm(reord.get_reordered()):
        if isinstance(until, str):
            until = [until]

        context_enc = torch.tensor(
            [self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]
        ).to(self.device)

        primary_until, = self.tokenizer.encode(until[0])

        cont = self.gpt2.generate(
            context_enc,
            max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
            eos_token_id=primary_until,
            do_sample=False,
        )

        s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])

        for term in until:
            s = s.split(term)[0]

        res.append(s)

    return reord.get_original(res)
def greedy_until(self, requests):
    if not requests:
        return []
    res = []

    def _collate(x):
        toks = self.tok_encode(x[0])
        return len(toks), x[0]

    re_ord = utils.Reorderer(requests, _collate)

    def sameuntil_chunks(xs, size):
        ret = []
        lastuntil = xs[0][1]
        for x in xs:
            if len(ret) >= size or x[1] != lastuntil:
                yield ret, lastuntil
                ret = []
                lastuntil = x[1]
            ret.append(x)

        if ret:
            yield ret, lastuntil

    # todo: more intelligent batching for heterogeneous `until`
    for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))):
        inps = []
        for context, _ in chunk:
            context_enc = self.tok_encode(context)
            inp = context_enc[-(self.max_length - self.max_gen_toks):]
            inps.append(inp)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            max_tokens=self.max_gen_toks,
            temperature=0.0,
            logprobs=10,
            stop=until,
        )

        for resp, (context, until_) in zip(response.choices, chunk):
            s = resp["text"]

            for term in until_:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until_), s)

            res.append(s)

    return re_ord.get_original(res)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    res = []

    def _collate(x):
        # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
        # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
        # we care about and so we need some kind of backup for when it isn't
        toks = x[1] + x[2]
        return -len(toks), tuple(toks)

    re_ord = utils.Reorderer(requests, _collate)

    for chunk in tqdm(
        list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
        disable=disable_tqdm,
    ):
        inps = []
        ctxlens = []
        for cache_key, context_enc, continuation_enc in chunk:
            # max_length+1 because the API takes up to 2049 tokens, including the first context token
            inp = (context_enc + continuation_enc)[-(self.max_length + 1):]
            # TODO: the logic is much simpler if we just look at the length of continuation tokens
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - (self.max_length + 1))

            inps.append(inp)
            ctxlens.append(ctxlen)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            echo=True,
            max_tokens=0,
            temperature=0.0,
            logprobs=10,
        )

        for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk):
            answer = get_result(resp, ctxlen)

            res.append(answer)

            # partial caching
            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

    return re_ord.get_original(res)
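# The API-backed scorer above hands each echoed completion to `get_result(resp, ctxlen)` to turn it
# into the (logprob_sum, is_greedy) pair the harness expects. A rough sketch of what such a helper
# has to do, assuming the response choice exposes `logprobs.tokens`, `token_logprobs`, and
# `top_logprobs` the way the old Completions API did (the real helper may differ in details):
def get_result_sketch(resp, ctxlen):
    # sum the log-probs of the continuation tokens only (everything after the context)
    continuation_logprobs = sum(resp["logprobs"]["token_logprobs"][ctxlen:])

    # the continuation counts as "greedy" only if every continuation token was also the single
    # most likely token at its position according to top_logprobs
    is_greedy = True
    for i in range(ctxlen, len(resp["logprobs"]["tokens"])):
        token = resp["logprobs"]["tokens"][i]
        top_tokens = resp["logprobs"]["top_logprobs"][i]
        if max(top_tokens, key=lambda t: top_tokens[t]) != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy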
def _loglikelihood_tokens(self, requests):
    import openai
    res = []

    def _collate(x):
        # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
        # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
        # we care about and so we need some kind of backup for when it isn't
        toks = x[1] + x[2]
        return (len(toks), tuple(toks))

    reord = utils.Reorderer(requests, _collate)

    for chunk in tqdm(
            list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
        inps = []
        ctxlens = []
        for cache_key, context_enc, continuation_enc in chunk:
            inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)

            inps.append(inp)
            ctxlens.append(ctxlen)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            echo=True,
            max_tokens=0,
            temperature=0.,
            logprobs=10,
        )

        for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk):
            answer = get_result(resp, ctxlen)

            res.append(answer)

            # partial caching
            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

    return reord.get_original(res)
def greedy_until(self, requests):
    if not requests:
        return []
    import openai
    res = []

    def _collate(x):
        toks = self.tokenizer.encode(x[0])
        return (len(toks), x[0])

    reord = utils.Reorderer(requests, _collate)

    def sameuntil_chunks(xs, size):
        ret = []
        lastuntil = xs[0][1]
        for x in xs:
            if len(ret) >= size or x[1] != lastuntil:
                yield ret, lastuntil
                ret = []
                lastuntil = x[1]
            ret.append(x)

        if ret:
            yield ret, lastuntil

    # todo: more intelligent batching for heterogeneous `until`
    for chunk, until in tqdm(
            list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
        inps = []
        for context, _ in chunk:
            context_enc = self.tokenizer.encode(context)
            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
            inps.append(inp)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            max_tokens=self.MAX_GEN_TOKS,
            temperature=0.,
            logprobs=10,
            stop=until,
        )

        for resp in response.choices:
            s = resp['text']

            for term in until:
                s = s.split(term)[0]

            res.append(s)

    return reord.get_original(res)
def loglikelihood(self, requests):
    # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
    res = []
    with torch.no_grad():
        # TODO: vectorize properly
        # TODO: automatic batch size detection for vectorization

        def _collate(x):
            toks = self.tokenizer.encode(x[0] + x[1])
            return (len(toks), x)

        reord = utils.Reorderer(requests, _collate)
        for context, continuation in tqdm(reord.get_reordered()):
            # when too long to fit in context, truncate from the left
            combined_toks = self.tokenizer.encode(context + continuation)

            if context == "":
                # end of text as context
                context_enc = [50256]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            inp = torch.tensor(
                [(context_enc + continuation_enc)[-self.max_length:]],
                dtype=torch.long).to(self.device)
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - self.max_length)

            cont_toks = inp[:, ctxlen:]  # [batch, seq]

            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

            greedy_tokens = logits.argmax(dim=-1)
            max_equal = (greedy_tokens == cont_toks).all()
            last_token_slice = logits[:, -1, :].squeeze(0).tolist()

            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

            res.append((float(logits.sum()), bool(max_equal)))

    return reord.get_original(res)
def _loglikelihood_tokens(self, requests):
    # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
    res = []
    with torch.no_grad():
        # TODO: vectorize properly
        # TODO: automatic batch size detection for vectorization

        def _collate(x):
            toks = x[1] + x[2]
            return (len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)
        for cache_key, context_enc, continuation_enc in tqdm(
                reord.get_reordered()):
            # when too long to fit in context, truncate from the left
            inp = torch.tensor(
                [(context_enc + continuation_enc)[-self.max_length:]],
                dtype=torch.long).to(self.device)
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - self.max_length)

            cont_toks = inp[:, ctxlen:]  # [batch, seq]

            logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257],
                                   dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

            greedy_tokens = logits.argmax(dim=-1)
            max_equal = (greedy_tokens == cont_toks).all()
            last_token_slice = logits[:, -1, :].squeeze(0).tolist()

            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

            answer = (float(logits.sum()), bool(max_equal))

            # partial caching
            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

            res.append(answer)

    return reord.get_original(res)
def greedy_until(self, requests):
    # TODO: implement fully general `until` that handles untils that are
    # multiple tokens or that span multiple tokens correctly
    res = []

    def _collate(x):
        toks = self.tokenizer.encode(x[0])
        return (len(toks), x[0])

    reord = utils.Reorderer(requests, _collate)

    for i, (context, until) in enumerate(tqdm(reord.get_reordered())):
        if isinstance(until, str):
            until = [until]

        context_enc = torch.tensor(
            [self.tokenizer.encode(context.strip() + "<extra_id_0>.")[-self.max_length:]]
        ).to(self.device)

        # primary_until, = self.tokenizer.encode(until[0])

        cont = self.model.generate(
            context_enc,
            max_length=self.MAX_GEN_TOKS,
            eos_token_id=self.tokenizer.convert_tokens_to_ids("<extra_id_1>"),
            do_sample=False,
        )

        s = self.tokenizer.decode(cont[0].tolist()[2:-1])

        if i < 100:
            print(f"Context: {self.tokenizer.decode(context_enc[0])}\nGeneration: {s}")

        if "<extra_id_1>" in s:
            s = s[:s.index("<extra_id_1>")]
        if "</s>" in s:
            s = s[:s.index("</s>")]

        for term in until:
            s = s.split(term)[0]

        if i < 100:
            print(f"Final S: {s}")

        # partial caching
        self.cache_hook.add_partial("greedy_until", (context, until), s)

        res.append(s)

    return reord.get_original(res)
def greedy_until(self, requests):
    """
    Greedy until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
    The eval harness dispatches requests to the model, and the model does argmax generation, the
    results of which are returned to the eval harness to evaluate.

    TODO: batched / data parallel generation

    :param requests: List of requests, each containing the context (prompt) and 'until' - a stop
                     string or list of stop strings.
    """
    self.model.module.inference_mode(use_cache=True)  # tell model to cache kv pairs
    res = []

    def _collate(x):
        toks = self.tokenizer.encode(x[0])
        return (len(toks), x[0])

    reord = utils.Reorderer(requests, _collate)

    for context, until in tqdm(reord.get_reordered(), "Running greedy generation"):
        if isinstance(until, str):
            until = [until]
        stop_tokens = [self.tokenizer.encode(i) for i in until]
        cont = self.generate(
            text=context,
            stop_tokens=stop_tokens,
            recompute=self.neox_args.recompute,
        )
        if cont:
            s = cont[0]["text"] or ""
        else:
            s = ""

        for term in until:
            s = s.split(term)[0]

        # partial caching
        self.cache_hook.add_partial("greedy_until", (context, until), s)

        res.append(s)

    self.model.module.train_mode()  # set back to train mode
    return reord.get_original(res)
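# Every variant in this file leans on `utils.Reorderer` to sort requests (e.g. by tokenized
# length) before batching and then restore the caller's original ordering. A minimal sketch of
# that contract, written purely for illustration from how it is used above (the real lm-eval
# implementation may differ in details); `_ReordererSketch` is a hypothetical name:
class _ReordererSketch:
    def __init__(self, arr, fn):
        # remember each request's original position, then sort by the collate key
        self.size = len(arr)
        self.indexed = sorted(enumerate(arr), key=lambda pair: fn(pair[1]))

    def get_reordered(self):
        # requests in sorted order, e.g. shortest-first or longest-first depending on the key
        return [req for _, req in self.indexed]

    def get_original(self, results):
        # scatter results back to the positions the caller passed them in
        out = [None] * self.size
        for (orig_idx, _), result in zip(self.indexed, results):
            out[orig_idx] = result
        return out

# Usage pattern mirrored from the functions here: reord = _ReordererSketch(requests, _collate);
# iterate over reord.get_reordered(); finally return reord.get_original(res) so results line up
# with the original `requests`.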
def loglikelihood(self, requests):
    import openai
    res = []

    def _collate(x):
        # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
        # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
        # we care about and so we need some kind of backup for when it isn't
        toks = self.tokenizer.encode(x[0] + x[1])
        return (len(toks), self.tokenizer.decode(toks))

    reord = utils.Reorderer(requests, _collate)

    for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
        inps = []
        ctxlens = []
        for context, continuation in chunk:
            if context == "":
                # end of text as context
                context_enc = [50256]
            else:
                context_enc = self.tokenizer.encode(context)

            continuation_enc = self.tokenizer.encode(continuation)

            inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)

            inps.append(inp)
            ctxlens.append(ctxlen)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            echo=True,
            max_tokens=0,
            temperature=0.,
            logprobs=10,
        )

        for resp, ctxlen in zip(response.choices, ctxlens):
            res.append(get_result(resp, ctxlen))

    return reord.get_original(res)
def greedy_until(self, requests):
    # TODO: implement fully general `until` that handles until that are
    # multiple tokens or that span multiple tokens correctly
    # TODO: extract to TokenizedLM?
    res = []

    def _collate(x):
        toks = self.tok_encode(x[0])
        return len(toks), x[0]

    re_ord = utils.Reorderer(requests, _collate)

    for context, until in tqdm(re_ord.get_reordered()):
        if isinstance(until, str):
            until = [until]

        (primary_until, ) = self.tok_encode(until[0])

        context_enc = torch.tensor(
            [self.tok_encode(context)[self.max_gen_toks - self.max_length:]]
        ).to(self.device)

        cont = self._model_generate(
            context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

        s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

        for term in until:
            s = s.split(term)[0]

        # partial caching
        self.cache_hook.add_partial("greedy_until", (context, until), s)

        res.append(s)

    return re_ord.get_original(res)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
    res = []

    def _collate(x):
        # the negative sign on len(toks) sorts descending - this has a few advantages:
        # - time estimates will always be over not underestimates, which is more useful for planning
        # - to know the size of a batch when going through the list, you know the first one is always the batch
        #   padded context length. this is useful to simplify the batching logic and more importantly to make
        #   automatic adaptive batches much much easier to implement
        # - any OOMs will happen right away rather than near the end
        toks = x[1] + x[2]
        return -len(toks), tuple(toks)

    # TODO: automatic (variable) batch size detection for vectorization
    re_ord = utils.Reorderer(requests, _collate)
    for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size):
        inps = []
        cont_toks_list = []
        inplens = []

        padding_length = None

        # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
        # tensors, then we pack them together into a batch, call the model, and then pick it all apart
        # again because vectorizing is annoying

        for _, context_enc, continuation_enc in chunk:
            # sanity check
            assert len(context_enc) > 0
            assert len(continuation_enc) > 0
            assert len(continuation_enc) <= self.max_length

            # how this all works:
            #          CTX      CONT
            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
            # gpt2    \               \
            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
            # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

            # when too long to fit in context, truncate from the left
            inp = torch.tensor(
                (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                dtype=torch.long,
            ).to(self.device)
            (inplen, ) = inp.shape

            cont = continuation_enc

            # since in _collate we make sure length is descending, the longest is always the first one.
            padding_length = (padding_length
                              if padding_length is not None else inplen)

            # pad length from seq to padding_length
            inp = torch.cat(
                [
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(
                        inp.device),  # [padding_length - seq]
                ],
                dim=0,
            )

            inps.append(inp.unsqueeze(0))  # [1, padding_length]
            cont_toks_list.append(cont)
            inplens.append(inplen)

        batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
        multi_logits = F.log_softmax(
            self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]

        for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list):

            # Slice to original seq length
            contlen = len(cont_toks)
            logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

            # Check if per-token argmax is exactly equal to continuation
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()

            # Obtain log-probs at the corresponding continuation token indices
            # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

            # Answer: (log prob, is-exact-match)
            answer = (float(logits.sum()), bool(max_equal))

            # partial caching
            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

            res.append(answer)

    return re_ord.get_original(res)
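# The alignment trick in the diagram above (feed ctx+cont minus the final token, then read the
# continuation's log-probs from the shifted logits) is easy to get wrong, so here is a tiny
# standalone worked example of the same gather pattern. The token ids and vocab size are made up
# for illustration; this only demonstrates the indexing, not the harness API.
import torch
import torch.nn.functional as F

vocab_size = 10
context_enc = [0, 1, 2, 3]        # "CTX" tokens
continuation_enc = [4, 5, 6]      # "CONT" tokens we want to score

inp = torch.tensor([(context_enc + continuation_enc)[:-1]])      # last token dropped
fake_logits = torch.randn(1, inp.shape[1], vocab_size)           # stand-in for model output
logprobs = F.log_softmax(fake_logits, dim=-1)

# logits at position i predict token i+1, so the last len(continuation_enc) positions
# line up with the continuation tokens
cont_slice = logprobs[:, -len(continuation_enc):, :]              # [1, cont_len, vocab]
cont_toks = torch.tensor([continuation_enc])                      # [1, cont_len]

token_logprobs = torch.gather(cont_slice, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
greedy_match = (cont_slice.argmax(dim=-1) == cont_toks).all()
answer = (float(token_logprobs.sum()), bool(greedy_match))        # (log prob, is-exact-match)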
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    """
    In this method, the model doesn't do any generation, but just returns log likelihoods
    for the next token, which eval harness uses to evaluate.

    :param requests: List of requests, each containing the context and the expected continuation.
    :param disable_tqdm: If True, disable tqdm progress bar.
    """
    self.model.module.inference_mode(
        use_cache=False
    )  # tell model to gather parallel outputs, but not cache key-value pairs

    disable_tqdm = disable_tqdm if self.is_main else True
    res = []
    res_len = 0  # storing the result length for later
    with torch.no_grad():

        def _collate(x):
            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps, contlens, inplens, padding_length = [], [], [], None
            for cache_key, context_enc, continuation_enc in chunk:
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad to length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

            logits = self._model_call(torch.cat(inps, dim=0))
            res_len += len(chunk)

            if logits is not None:
                multi_logits = F.log_softmax(logits, dim=-1)  # [batch, seq, vocab]

                for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                    chunk, multi_logits, inps, inplens, contlens
                ):
                    contlen = len(cont_toks)
                    logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
                    greedy_tokens = logits.argmax(dim=-1)

                    # cont_toks :: [1, seq]
                    cont_toks = (
                        torch.tensor(cont_toks, dtype=torch.long)
                        .unsqueeze(0)
                        .to(multi_logits.device)
                    )
                    max_equal = (greedy_tokens == cont_toks).all()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                    answer = (float(logits.sum()), bool(max_equal))

                    # partial caching
                    if cache_key is not None:
                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                    res.append(answer)

    # broadcast results to all ranks
    if self.is_pipe_parallel:
        src_rank = self.model.grid.stage_to_global(self.model.num_stages - 1)
        if res:
            logits_sums, max_equals = list(zip(*res))
            logits_sums = torch.FloatTensor(logits_sums).cuda()
            max_equals = torch.LongTensor(max_equals).cuda()
        else:
            logits_sums = torch.zeros(res_len, dtype=torch.float32).cuda()
            max_equals = torch.zeros(res_len, dtype=torch.int64).cuda()
        torch.distributed.broadcast(
            tensor=logits_sums,
            src=src_rank,
            group=mpu.get_pipe_parallel_group(),
        )
        torch.distributed.broadcast(
            tensor=max_equals, src=src_rank, group=mpu.get_pipe_parallel_group()
        )
        max_equals = [bool(i) for i in max_equals.tolist()]
        logits_sums = logits_sums.tolist()
        res = list(zip(logits_sums, max_equals))

    self.model.module.train_mode()  # set back to train mode
    return reord.get_original(res)
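# The pipe-parallel broadcast at the end of the method above has to ship Python tuples across
# ranks, so it packs the (logprob_sum, is_greedy) pairs into two flat tensors and rebuilds the
# list afterwards. The round-trip can be checked without any distributed setup; this snippet just
# mirrors that pack/unpack step on CPU for illustration, with made-up values:
import torch

res = [(-3.25, True), (-7.5, False), (-0.125, True)]

logits_sums, max_equals = list(zip(*res))
logits_sums = torch.FloatTensor(logits_sums)   # would be .cuda() and broadcast in the real path
max_equals = torch.LongTensor(max_equals)      # bools travel as 0/1 integers

# after torch.distributed.broadcast(...) every rank would hold the same tensors; unpack them:
rebuilt = list(zip(logits_sums.tolist(), [bool(i) for i in max_equals.tolist()]))
assert rebuilt == [(-3.25, True), (-7.5, False), (-0.125, True)]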
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    disable_tqdm = disable_tqdm if self.is_main else True
    res = []
    res_len = 0  # storing the result length for later
    with torch.no_grad():

        def _collate(x):
            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
                tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps, contlens, inplens, padding_length = [], [], [], None
            for _, context_enc, continuation_enc in chunk:
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                    dtype=torch.long).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad to length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

            logits = self._model_call(torch.cat(inps, dim=0))
            res_len += len(chunk)
            if logits is not None:
                multi_logits = F.log_softmax(logits, dim=-1)  # [batch, seq, vocab]

                for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                        chunk, multi_logits, inps, inplens, contlens):
                    contlen = len(cont_toks)
                    logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
                    greedy_tokens = logits.argmax(dim=-1)

                    # cont_toks :: [1, seq]
                    cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0).to(multi_logits.device)
                    max_equal = (greedy_tokens == cont_toks).all()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                    answer = (float(logits.sum()), bool(max_equal))

                    # partial caching
                    if cache_key is not None:
                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                    res.append(answer)

    # broadcast results to all ranks
    if self.is_pipe_parallel:
        src_rank = self.model.grid.stage_to_global(self.model.num_stages - 1)
        if res:
            logits_sums, max_equals = list(zip(*res))
            logits_sums = torch.FloatTensor(logits_sums).cuda()
            max_equals = torch.LongTensor(max_equals).cuda()
        else:
            logits_sums = torch.zeros(res_len, dtype=torch.float32).cuda()
            max_equals = torch.zeros(res_len, dtype=torch.int64).cuda()
        torch.distributed.broadcast(tensor=logits_sums, src=src_rank)
        torch.distributed.broadcast(tensor=max_equals, src=src_rank)
        max_equals = [bool(i) for i in max_equals.tolist()]
        logits_sums = logits_sums.tolist()
        res = list(zip(logits_sums, max_equals))

    return reord.get_original(res)