def _prepare_sample(self, sample):
    if sample is None or len(sample) == 0:
        return None

    if self.cuda:
        sample = utils.move_to_cuda(sample)

    def apply_half(t):
        if t.dtype is torch.float32:
            return t.half()
        return t

    if self.args.fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    return sample
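# The helpers in this file lean on fairseq's utils.apply_to_sample, which walks a
# nested sample (dicts / lists / tuples of tensors) and applies a function to every
# tensor it finds. A minimal sketch of that behavior, for reference only; the real
# fairseq implementation handles a few more container types and edge cases:
import torch

def apply_to_sample_sketch(f, sample):
    """Apply f to every torch.Tensor found in a nested sample structure."""
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x_) for x_ in x]
        elif isinstance(x, tuple):
            return tuple(_apply(x_) for x_ in x)
        else:
            return x

    return _apply(sample)

# Example: the fp16 cast used in _prepare_sample above could be expressed as
#   sample = apply_to_sample_sketch(
#       lambda t: t.half() if t.dtype is torch.float32 else t, sample)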
def _reduce_and_log_stats(self, logging_outputs, sample_size):
    # with metrics.aggregate() as agg:
    # convert logging_outputs to CPU to avoid unnecessary
    # device-to-host transfers in reduce_metrics
    logging_outputs = utils.apply_to_sample(
        lambda t: t.to(device='cpu', non_blocking=True), logging_outputs)
    # self.task.reduce_metrics(logging_outputs, self.get_criterion())
    # support legacy interface
    # logging_output = agg.get_smoothed_values()
    logging_output = logging_outputs[-1]
    logging_output["sample_size"] = sample_size
    for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
        if key_to_delete in logging_output:
            del logging_output[key_to_delete]
    return logging_output
def translate_batch(model, sids, sentences):
    input = [model.encode(sentence) for sentence in sentences]
    lengths = [len(t) for t in input]
    dataset = model.task.build_dataset_for_inference(input, lengths)
    samples = dataset.collater(dataset)
    samples = utils.apply_to_sample(
        lambda tensor: tensor.to(model.device), samples
    )
    ids = samples['id'].cpu()
    generator = model.task.build_generator(model.args)
    translations = model.task.inference_step(generator, model.models, samples)
    hypos = [translation[0]['tokens'] for translation in translations]
    translated = [model.decode(hypo) for hypo in hypos]
    return OrderedDict([(sids[id], tr) for id, tr in zip(ids, translated)])
def build_sample(
    model,
    src_tokens: List[torch.LongTensor],
    tgt_tokens: List[torch.LongTensor],
):
    # assert torch.is_tensor(src_tokens)
    dataset = LanguagePairDataset(
        src_tokens,
        [x.numel() for x in src_tokens],
        model.task.source_dictionary,
        tgt=tgt_tokens,
        tgt_sizes=[x.numel() for x in tgt_tokens],
        tgt_dict=model.task.target_dictionary,
    )
    sample = dataset.collater(dataset)
    sample = utils.apply_to_sample(lambda tensor: tensor.to(model.device), sample)
    return sample
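# Hypothetical usage of build_sample above, assuming `model` is a fairseq hub
# interface (e.g. loaded via TransformerModel.from_pretrained) whose encode()
# returns a LongTensor of token ids; the sentences here are placeholders:
#
#   src_tokens = [model.encode(s) for s in ["hello world", "how are you"]]
#   tgt_tokens = [model.encode(t) for t in ["hallo welt", "wie geht es dir"]]
#   sample = build_sample(model, src_tokens, tgt_tokens)
#   # sample["net_input"]["src_tokens"] is now a padded batch on model.device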
def start(self, start_with_nothing):
    state = LMState()
    prefix = torch.LongTensor([[self.dictionary.eos()]])
    incremental_state = {} if self.save_incremental else None
    with torch.no_grad():
        res = self.model(prefix.cuda(), incremental_state=incremental_state)
        probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)

    if incremental_state is not None:
        incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
    self.states[state] = FairseqLMState(
        prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
    )
    self.stateq.append(state)

    return state
def _prepare_sample(self, sample, dummy=False):
    if sample is None or len(sample) == 0:
        return None

    if self.args.task == 'doc_translation' and not dummy:
        sample = self._prepare_sample_with_context(sample)

    if self.cuda:
        sample = utils.move_to_cuda(sample)

    def apply_half(t):
        if t.dtype is torch.float32:
            return t.half()
        return t

    if self.args.fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    return sample
def prepare_sample(args, task, sample, use_cuda=True):
    if sample is None or len(sample) == 0:
        return None

    if args.task == 'doc_translation':
        sample = prepare_sample_with_context(task, sample)

    if use_cuda:
        sample = utils.move_to_cuda(sample)

    def apply_half(t):
        if t.dtype is torch.float32:
            return t.half()
        return t

    if args.fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    return sample
def _build_sample(self, src_tokens: List[torch.LongTensor],
                  src_sent_ids=None, chains_dataset=None):
    # assert torch.is_tensor(src_tokens)
    if src_sent_ids is not None:
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
            src_sent_ids=src_sent_ids,
            chains_dataset=chains_dataset,
            explicit_str_att=chains_dataset is not None)
    else:
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
        )
    sample = dataset.collater(dataset)
    sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
    return sample
def _build_sample(self, src_tokens: List[torch.LongTensor], src_tokens2=None):
    # assert torch.is_tensor(src_tokens)
    if src_tokens2 is None:
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
        )
    else:
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
            src_tokens2,
            [x.numel() for x in src_tokens2],
        )
    # print(self.device)
    sample = dataset.collater(dataset)
    sample = utils.apply_to_sample(
        lambda tensor: tensor.to(self.device), sample
    )
    return sample
def _prepare_sample(self, sample):
    if sample == "DUMMY":
        raise Exception(
            "Trying to use an uninitialized 'dummy' batch. This usually indicates "
            "that the total number of batches is smaller than the number of "
            "participating GPUs. Try reducing the batch size or using fewer GPUs."
        )

    if sample is None or len(sample) == 0:
        return None

    if self.cuda:
        sample = utils.move_to_cuda(sample)

    def apply_half(t):
        if t.dtype is torch.float32:
            return t.half()
        return t

    if self.args.fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    return sample
def batch_augments(self, sentences, batch_size=30, progress_bar=True):
    self.from_model.eval()
    self.to_model.eval()
    result = []
    oom = False
    batch_ind = 0
    iterator = tqdm(range(len(sentences) // batch_size + 1)) if progress_bar \
        else range(len(sentences) // batch_size + 1)
    try:
        for batch_ind in iterator:
            inputs = [
                self.from_model.encode(sample)
                for sample in sentences[batch_ind * batch_size:(batch_ind + 1) * batch_size]
            ]
            if len(inputs) > 0:
                # forward translation
                dataset = self.from_model.task.build_dataset_for_inference(
                    inputs, [input.numel() for input in inputs])
                sample = dataset.collater(dataset)
                sample = utils.apply_to_sample(
                    lambda tensor: tensor.to(self.from_model.device), sample)
                gen_args = copy.copy(self.from_model.args)
                gen_args.beam = self.from_num_beam
                generator = self.from_model.task.build_generator(
                    self.from_model.models, args=gen_args)
                translations = self.from_model.task.inference_step(
                    generator, self.from_model.models, sample)
                translations = [
                    self.from_model.decode(tr[0]['tokens']) for tr in translations
                ]
                translations = [
                    translations[sample['id'].tolist().index(i)]
                    for i in range(len(translations))
                ]
                # back translation
                translations = [
                    self.to_model.encode(sample) for sample in translations
                ]
                dataset = self.to_model.task.build_dataset_for_inference(
                    translations, [input.numel() for input in translations])
                sample = dataset.collater(dataset)
                sample = utils.apply_to_sample(
                    lambda tensor: tensor.to(self.to_model.device), sample)
                gen_args = copy.copy(self.to_model.args)
                gen_args.beam = self.to_num_beam
                generator = self.to_model.task.build_generator(
                    self.to_model.models, args=gen_args)
                back_translations = self.to_model.task.inference_step(
                    generator, self.to_model.models, sample)
                back_translations = [
                    self.to_model.decode(tr[0]['tokens']) for tr in back_translations
                ]
                back_translations = [
                    back_translations[sample['id'].tolist().index(i)]
                    for i in range(len(back_translations))
                ]
                result.extend(back_translations)
    except RuntimeError:
        torch.cuda.empty_cache()
        gc.collect()
        oom = True
    if oom:
        # retry the failed chunk with a smaller batch size, then continue
        result.extend(
            self.batch_augments(
                sentences[batch_ind * batch_size:(batch_ind + 1) * batch_size],
                batch_size=batch_size // 2,
                progress_bar=False))
        result.extend(
            self.batch_augments(sentences[(batch_ind + 1) * batch_size:],
                                batch_size=batch_size))
    return result
def score(self, state: LMState, token_index: int, no_cache: bool = False):
    """
    Evaluate language model based on the current lm state and new word

    Parameters:
    -----------
    state: current lm state
    token_index: index of the word
                 (can be lexicon index then you should store inside LM the
                 mapping between indices of lexicon and lm, or lm index of a word)

    Returns:
    --------
    (LMState, float): pair of (new state, score for the current word)
    """
    curr_state = self.states[state]

    def trim_cache(targ_size):
        while len(self.stateq) > targ_size:
            rem_k = self.stateq.popleft()
            rem_st = self.states[rem_k]
            rem_st = FairseqLMState(rem_st.prefix, None, None)
            self.states[rem_k] = rem_st

    if curr_state.probs is None:
        new_incremental_state = (
            curr_state.incremental_state.copy()
            if curr_state.incremental_state is not None
            else None
        )
        with torch.no_grad():
            if new_incremental_state is not None:
                new_incremental_state = apply_to_sample(
                    lambda x: x.cuda(), new_incremental_state
                )
            elif self.save_incremental:
                new_incremental_state = {}

            res = self.model(
                torch.from_numpy(curr_state.prefix).cuda(),
                incremental_state=new_incremental_state,
            )
            probs = self.model.get_normalized_probs(
                res, log_probs=True, sample=None
            )

            if new_incremental_state is not None:
                new_incremental_state = apply_to_sample(
                    lambda x: x.cpu(), new_incremental_state
                )

            curr_state = FairseqLMState(
                curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
            )

        if not no_cache:
            self.states[state] = curr_state
            self.stateq.append(state)

    score = curr_state.probs[token_index].item()

    trim_cache(self.max_cache)

    outstate = state.child(token_index)
    if outstate not in self.states and not no_cache:
        prefix = np.concatenate(
            [curr_state.prefix, torch.LongTensor([[token_index]])], -1
        )
        incr_state = curr_state.incremental_state

        self.states[outstate] = FairseqLMState(prefix, incr_state, None)

    if token_index == self.unk:
        score = float("-inf")

    return outstate, score
def generate(
    self,
    tokenized_sentences: List[torch.LongTensor],
    beam: int = 5,
    verbose: bool = False,
    skip_invalid_size_inputs=False,
    inference_step_args=None,
    **kwargs
) -> List[List[Dict[str, torch.Tensor]]]:
    if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
        return self.generate(
            tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
        )[0]

    # build generator using current args as well as any kwargs
    gen_args = copy.deepcopy(self.cfg.generation)
    with open_dict(gen_args):
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
    generator = self.task.build_generator(self.models, gen_args)

    inference_step_args = inference_step_args or {}
    results = []
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
        batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
        translations = self.task.inference_step(
            generator, self.models, batch, **inference_step_args
        )
        for id, hypos in zip(batch["id"].tolist(), translations):
            results.append((id, hypos))

    # sort output to match input order
    outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

    if verbose:

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.cfg, name, default))

        for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
            src_str_with_unk = self.src_dict.string(source_tokens)
            logger.info("S\t{}".format(src_str_with_unk))
            for hypo in target_hypotheses:
                hypo_str = self.decode(hypo["tokens"])
                logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                logger.info(
                    "P\t{}".format(
                        " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                hypo["positional_scores"].tolist(),
                            )
                        )
                    )
                )
                if hypo["alignment"] is not None and getarg(
                    "print_alignment", False
                ):
                    logger.info(
                        "A\t{}".format(
                            " ".join(
                                [
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in hypo["alignment"]
                                ]
                            )
                        )
                    )
    return outputs
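# Hedged usage sketch for generate() above, assuming a hub-style model that
# exposes encode()/decode() (e.g. one returned by TransformerModel.from_pretrained);
# the example sentences are placeholders:
#
#   tokenized = [model.encode(s) for s in ["hello world", "good morning"]]
#   hypos = model.generate(tokenized, beam=5, verbose=True)
#   best = [model.decode(h[0]["tokens"]) for h in hypos]  # top hypothesis per input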
def variable_beam_stream_fast(sg, model, tokenized_sentences, k=5, max_length=100,
                              rp=0.6, ap=2.5, rpl=0.02, mc=3, find_top_z=1,
                              max_indices=32, encode_batch_size=64,
                              max_si_tokens=7168, bos_token=None, len_penalty=1,
                              one_batch=False):
    ensemble_size = len(model.models)
    BOS_ID = sg.eos if bos_token is None else bos_token
    EOS_ID = sg.eos
    if one_batch:
        full_data_size = tokenized_sentences['net_input']['src_tokens'].shape[0]
    else:
        full_data_size = len(tokenized_sentences)
        batch_iterator = model._build_batches(tokenized_sentences, False)  # not streaming
    master_done_beams = [[] for _ in range(full_data_size)]
    master_batch_ids = [None for _ in range(full_data_size)]
    parent_model = model
    model = model.models
    master_decoded_indices = torch.zeros(1, 0, k).long().to(parent_model.device)  # seq, batch, k
    master_log_probs = torch.zeros(0, k).to(parent_model.device)  # batch x k
    master_enc_out = []
    master_state = IncrementalState(0, k, ensemble_size, parent_model.device)  # init incremental state
    master_valid_beam_mask = torch.zeros(0, k).to(parent_model.device)  # batch x k
    master_num_valid_beams = torch.zeros(0).long().to(parent_model.device)  # batch
    master_index = torch.zeros(0).long().to(parent_model.device)  # batch
    master_src_lengths = torch.zeros(0).long().to(parent_model.device)
    master_progress = torch.zeros(0).long().to(parent_model.device)  # batch
    master_end_found = torch.zeros(0, k).long().to(parent_model.device)  # batch x k
    master_done_lengths = torch.zeros(0).long().to(parent_model.device)  # batch
    master_best_finished_log_probs = torch.zeros(0).to(parent_model.device) - 1e8  # batch
    current_idx = 0
    has_more_batches = True
    decode_calls = 0
    n_expansions = 0
    master_remove_indices = torch.zeros(0).long().to(parent_model.device)
    num_pad = 0
    reselect = True
    while True:
        while has_more_batches and master_src_lengths.sum() \
                <= max_si_tokens - parent_model.args.max_tokens:  # token-based limit
            assert reselect
            if one_batch:  # not streaming
                batch = tokenized_sentences
                has_more_batches = False
            else:
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    has_more_batches = False
                    break
            batch = utils.apply_to_sample(lambda t: t.to(parent_model.device), batch)
            for i, id in enumerate(batch['id'].tolist()):
                master_batch_ids[current_idx + i] = id
            net_input = batch["net_input"]
            src_tokens = net_input["src_tokens"]
            num_new_sources = len(src_tokens)
            # encode add the next batch of source infos; update the index
            encoder_outs = sg.model.forward_encoder(net_input)
            # concatenate to the current master tensors
            # decoded_indices; note these are left padded
            current_seqlen = master_decoded_indices.size(0)
            master_decoded_indices = torch.cat([
                master_decoded_indices,
                pad_to_length(torch.zeros(1, num_new_sources, k) + BOS_ID,
                              current_seqlen, 0, side='left',
                              value=0).long().to(parent_model.device)
            ], dim=1)
            # log_probs
            master_log_probs = torch.cat([
                master_log_probs,
                torch.cat([
                    torch.zeros(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1) - 1e8
                ], dim=1).to(parent_model.device)
            ], dim=0)
            if len(master_enc_out) == 0:
                assert current_idx == 0
                master_enc_out = encoder_outs
            else:
                assert len(master_enc_out) == len(encoder_outs)
                for i in range(len(master_enc_out)):
                    meo, eo = master_enc_out[i], encoder_outs[i]
                    max_seq = max(meo.encoder_out.shape[0], eo.encoder_out.shape[0])
                    new_eo = EncoderOut(
                        encoder_out=torch.cat([
                            pad_to_length(meo.encoder_out, max_seq, 0, side='left', value=0),
                            pad_to_length(eo.encoder_out, max_seq, 0, side='left', value=0)
                        ], dim=1),
                        encoder_padding_mask=torch.cat([
                            pad_to_length(meo.encoder_padding_mask, max_seq, 1, side='left', value=True),
                            pad_to_length(eo.encoder_padding_mask, max_seq, 1, side='left', value=True)
                        ], dim=0),
                        encoder_embedding=torch.cat([
                            pad_to_length(meo.encoder_embedding, max_seq, 1, side='left', value=0),
                            pad_to_length(eo.encoder_embedding, max_seq, 1, side='left', value=0)
                        ], dim=0),
                        encoder_states=None,
                        src_tokens=None,
                        src_lengths=None)
                    master_enc_out[i] = new_eo
            if not one_batch:
                # get the encoder attention keys
                sg.model.incremental_states = [{} for _ in range(ensemble_size)]
                sg.model.forward_decoder(
                    (torch.zeros(num_new_sources) + BOS_ID).long().to(
                        parent_model.device).unsqueeze(1),
                    encoder_outs, sg.temperature)
                dummy_state = sg.model.incremental_states
                master_state.append_new_incremental_state(
                    num_new_sources, dummy_state,
                    torch.arange(num_new_sources).long().to(
                        parent_model.device) + current_idx)
            master_valid_beam_mask = torch.cat([
                master_valid_beam_mask,
                torch.cat([
                    torch.ones(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1)
                ], dim=1).to(parent_model.device)
            ], dim=0)
            # print(net_input['src_lengths'].max())
            master_src_lengths = torch.cat(
                [master_src_lengths, net_input['src_lengths']], dim=0)
            # num_valid_beams
            master_num_valid_beams = torch.cat([
                master_num_valid_beams,
                torch.ones(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # index
            master_index = torch.cat([
                master_index,
                current_idx + torch.arange(num_new_sources).to(parent_model.device)
            ], dim=0)
            # progress
            master_progress = torch.cat([
                master_progress,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # end_found
            master_end_found = torch.cat([
                master_end_found,
                torch.zeros(num_new_sources, k).long().to(parent_model.device)
            ], dim=0)
            # done lengths
            master_done_lengths = torch.cat([
                master_done_lengths,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # best done log probs
            master_best_finished_log_probs = torch.cat([
                master_best_finished_log_probs,
                torch.zeros(num_new_sources).to(parent_model.device) - 1e8
            ], dim=0)
            current_idx += num_new_sources
            # break  # for debugging

        # break if none left
        if not has_more_batches and len(master_index) == 0:
            break

        # based on max_bs and source_info, select which indices to use (sort source_info), then create:
        selected_indices, unselected_indices, prog_min = select_source_indices(
            master_num_valid_beams, master_progress, master_index, max_indices,
            reverse=False, sort=False)
        if one_batch:
            assert len(unselected_indices) == 0  # for debugging
        selected_master_indices = master_index[selected_indices]
        batch_size = len(selected_indices)
        selected_enc_out = sg.model.reorder_encoder_out(
            master_enc_out,
            selected_indices.unsqueeze(1).expand(-1, k).flatten())
        # if decode_calls % 50 == 0:
        #     print(decode_calls)
        valid_beam_mask = master_valid_beam_mask[selected_indices]
        valid_beam_indices = valid_beam_mask.flatten().nonzero().flatten()  # idk why need to flatten again
        reverse_idx = (torch.cumsum(valid_beam_mask.flatten(), dim=0)
                       * valid_beam_mask.flatten()).long() - 1
        # it's fine to select whatever position for padding as they'll be removed later
        if num_pad > 0:
            if num_pad >= len(master_decoded_indices):
                # edge case: we previously ran out of beams, and we are starting fresh now
                assert num_pad == len(master_decoded_indices)
                num_pad -= 1
            master_decoded_indices = master_decoded_indices[num_pad:]
            master_state.clean_padding(num_pad)
        if reselect:
            selected_state_master_indices, selected_state = master_state.select_incremental_state(
                selected_master_indices, master_remove_indices, prog_min)
            master_state.num_sources -= len(master_remove_indices)
            sg.model.incremental_states = selected_state
        log_probs = master_log_probs[selected_indices]
        progress = master_progress[selected_indices]
        decoded_indices = master_decoded_indices[-progress.max() - 1:, selected_indices, :]
        end_found = master_end_found[selected_indices]
        done_lengths = master_done_lengths[selected_indices]
        best_finished_log_probs = master_best_finished_log_probs[selected_indices]
        # flattened_indices = last_indices.flatten().unsqueeze(0)  # 1 x batch*k
        # create valid beam indices from valid beam mask
        if one_batch and decode_calls == 0:
            selected_state_master_indices = master_index.clone()
        assert len(selected_state_master_indices) == len(valid_beam_indices)
        decode_calls += 1
        n_expansions += len(valid_beam_indices)
        # use valid_beam_mask to select valid indices out of decoded_indices,
        # encoder_outs, model incremental state
        decoding_selected_indices = decoded_indices.flatten(1)[:, valid_beam_indices]  # seq x selected
        selected_enc_out = sg.model.reorder_encoder_out(selected_enc_out, valid_beam_indices)
        assert torch.all(decoding_selected_indices.flatten(1).permute(1, 0)[:, 0] == 2)
        next_log_probs, _ = sg.model.forward_decoder(
            decoding_selected_indices.flatten(1).permute(1, 0)[:, :master_progress.max() + 1],
            selected_enc_out, sg.temperature)
        # remake next_scores, state with dummies
        next_log_probs = next_log_probs[reverse_idx].view(1, batch_size, k, -1)
        # reorder incremental model state
        reorder_idx = reverse_idx
        next_log_probs = next_log_probs.view(1, batch_size, k, -1)
        # for edge case where EOS_ID appears later down in the beam but still needs
        # to be dealt with correctly on the next step!
        end_found = end_found.unsqueeze(0).unsqueeze(3)
        # batch_size x k x 1 of whether end index is in tgt_idx already; if so, make prob of padding 1
        end_found = (end_found
                     + (progress + 1 == max_length).long().view(1, -1, 1, 1)).clamp(max=1)
        end_found_scores = torch.zeros_like(next_log_probs).to(parent_model.device) - 1e8
        end_found_scores[:, :, :, EOS_ID] = 0
        # make it so you only pick eos for the sequences that are already done,
        # and don't duplicate them, by making other probs -inf
        next_log_probs = end_found * end_found_scores + (1 - end_found) * next_log_probs
        # ~ is for inverting the mask
        next_log_probs = next_log_probs - 1e8 * (
            1 - valid_beam_mask.unsqueeze(0).unsqueeze(3))  # get rid of padding positions
        next_log_probs = next_log_probs + log_probs.unsqueeze(0).unsqueeze(3)  # 1, batch, k, vocab
        mc_probs, mc_indices = next_log_probs.topk(mc, dim=3)  # 1, batch, k, mc
        top_log_probs, top_indices = mc_probs.flatten(2).topk(k, dim=2)  # 1, batch, k
        mc_vocab_indices = top_indices % mc
        beam_indices = top_indices // mc  # 1, batch, k
        vocab_indices = torch.gather(
            mc_indices.flatten(2).flatten(0, 1), 1,
            (mc_vocab_indices + beam_indices * mc).flatten(0, 1)).unsqueeze(0)  # 1, batch, k
        # check which vocab_indices are done (in the first beam position), and add
        # the corresponding beam to an array of done predictions
        newly_done_all = (vocab_indices == EOS_ID).long()  # 1, batch, k
        newly_done = torch.cumprod(newly_done_all, dim=2)  # keep on beam if there's something above it that's not done yet
        done_lengths += newly_done.sum(dim=2).flatten()  # update this one before others since we'll need it earlier
        newly_done_indices = newly_done.flatten().nonzero()  # batch*k
        for j in newly_done_indices:
            source_idx = j // k
            # add to some master list with an entry for each source
            if len(master_done_beams[selected_master_indices[source_idx]]) < find_top_z:
                finished_cand = decoded_indices[:, source_idx,
                                                beam_indices[0, source_idx, j % k]].flatten()
                finished_cand_length = progress[source_idx] + 1
                while len(finished_cand) > 0 and finished_cand[-1] == EOS_ID:
                    finished_cand = finished_cand[:-1]
                    finished_cand_length -= 1
                if len(finished_cand) > 0:  # avoid length 0
                    master_done_beams[selected_master_indices[source_idx]].append(
                        {'tokens': finished_cand.cpu(),
                         'score': (top_log_probs.flatten()[j]
                                   / ((finished_cand_length) ** len_penalty)).item()})
                    best_finished_log_probs[source_idx] = max(
                        best_finished_log_probs[source_idx],
                        top_log_probs.flatten()[j])
                else:
                    # rarely with greedy search (beam size k = 1) you get stuff with
                    # length 0... so avoid crashing but give it low score
                    master_done_beams[selected_master_indices[source_idx]].append(
                        {'tokens': finished_cand.cpu(), 'score': -1e8})
        # then, shift log_probs and beam_indices for those beams and delete that beam(s);
        # put in placeholder beam and log_prob at the k^th position
        # need to shift top_log_probs, beam_indices, vocab_indices accordingly
        top_log_probs = torch.cat([
            top_log_probs,
            torch.zeros_like(top_log_probs).to(parent_model.device) - 1e8
        ], dim=2)  # 1, batch, 2k
        shift_indices = newly_done.sum(dim=2).unsqueeze(2) + torch.arange(k).to(
            parent_model.device).unsqueeze(0).unsqueeze(1)  # 1, batch, k
        top_log_probs = torch.gather(top_log_probs, 2, shift_indices)
        shift_indices = shift_indices.clamp(max=k - 1)
        beam_indices = torch.gather(beam_indices, 2, shift_indices)
        vocab_indices = torch.gather(vocab_indices, 2, shift_indices)
        newly_done_all = torch.gather(newly_done_all, 2, shift_indices)
        log_probs = top_log_probs.squeeze(0)
        state_indices = (beam_indices + k * torch.arange(batch_size).to(
            parent_model.device).unsqueeze(1).repeat(1, k)).flatten()
        reorder_idx = reorder_idx[state_indices]
        # update valid beam mask
        ap_thresholds = (torch.max(log_probs[:, 0], best_finished_log_probs) - ap).unsqueeze(1)  # batch x 1
        valid_beam_mask = (log_probs > ap_thresholds).float()  # batch x k
        # update valid beam mask based on how many beams are left for each source
        done_mask = pad_mask(k - done_lengths, parent_model.device,
                             max_seqlen=k).permute(1, 0)  # batch x k of beams to keep, up to k - num done already
        all_low_prob_mask = 1 - valid_beam_mask.max(dim=1)[0]
        # NOTE since we filter out by the absolute threshold including previously
        # finished beams, we could get < k finished candidates, but always at least 1
        found_z_mask = (all_low_prob_mask.bool() | (done_lengths >= find_top_z)).unsqueeze(1)
        valid_beam_mask = valid_beam_mask * done_mask * (1 - found_z_mask.long())
        # filter the done ones out of all the master tensors
        keep_indices = (~found_z_mask).flatten().nonzero().flatten().long()
        remove_indices = (found_z_mask).flatten().nonzero().flatten().long()
        keep_indices = torch.cat([selected_indices[keep_indices], unselected_indices], dim=0)
        master_remove_indices = master_index[selected_indices[remove_indices]]
        # update these quantities in their respective source_info objects after computing them
        # just deleting/concatenating to a single master tensor
        # master_decoded_indices  seq x batch x k
        new_master_indices = torch.zeros(
            1, master_decoded_indices.size(1), k).long().to(parent_model.device)  # 1 x batch x k
        new_master_indices[:, selected_indices] = vocab_indices
        master_decoded_indices[:, selected_indices] = torch.gather(
            master_decoded_indices[:, selected_indices], 2,
            beam_indices.expand(master_decoded_indices[:, selected_indices].shape))
        master_decoded_indices = torch.cat(
            [master_decoded_indices, new_master_indices], dim=0)
        if prog_min + 2 >= master_decoded_indices.shape[0]:
            master_decoded_indices = torch.cat([
                torch.zeros(1, master_decoded_indices.size(1), k).long().to(parent_model.device),
                master_decoded_indices
            ], dim=0)
        master_decoded_indices[:, selected_indices] = torch.roll(
            master_decoded_indices[:, selected_indices], -1, 0)
        master_decoded_indices = master_decoded_indices[:-1]
        # master_log_probs  batch x k
        master_log_probs[selected_indices] = log_probs
        # master_valid_beam_mask  batch x k
        master_valid_beam_mask[selected_indices] = valid_beam_mask
        # master_num_valid_beams  batch
        master_num_valid_beams = master_valid_beam_mask.sum(dim=1).long()
        # master_progress  batch
        master_progress[selected_indices] += 1
        # master_end_found  batch x k
        master_end_found[selected_indices] = (
            torch.gather(end_found.squeeze(3), 2, beam_indices)
            | newly_done_all[0, :, :]).squeeze(0)
        # master_done_lengths  batch
        master_done_lengths[selected_indices] = done_lengths
        # master_best_finished_log_probs  batch
        master_best_finished_log_probs[selected_indices] = best_finished_log_probs
        # update master versions of sg.model state
        reorder_idx = reorder_idx[valid_beam_mask.flatten().nonzero().flatten()]
        selected_state_master_indices = selected_state_master_indices[reorder_idx]
        reorder_incremental_state(sg.model, reorder_idx)
        master_src_lengths = master_src_lengths[keep_indices]
        if master_src_lengths.sum() <= max_si_tokens - parent_model.args.max_tokens:
            reselect = True
        elif len(progress) < (master_progress == prog_min + 1).sum():
            reselect = True
        else:
            reselect = False
        if reselect:
            # if not one_batch:
            #     print('reselect', decode_calls)
            master_state.recache(selected_state_master_indices,
                                 sg.model.incremental_states)
            master_decoded_indices = master_decoded_indices[:, keep_indices, :]
            master_log_probs = master_log_probs[keep_indices]
            master_enc_out = sg.model.reorder_encoder_out(master_enc_out, keep_indices)
            master_valid_beam_mask = master_valid_beam_mask[keep_indices]
            master_num_valid_beams = master_num_valid_beams[keep_indices]
            master_index = master_index[keep_indices]
            master_progress = master_progress[keep_indices]
            master_end_found = master_end_found[keep_indices]
            master_done_lengths = master_done_lengths[keep_indices]
            master_best_finished_log_probs = master_best_finished_log_probs[keep_indices]
            # delete any unnecessary padding so we don't keep increasing padding
            num_pad = (master_decoded_indices.sum(dim=1).sum(dim=1) == 0).sum(dim=0)
        if not reselect:
            assert num_pad == 0

    assert all([bid is not None for bid in master_batch_ids])
    for i in range(len(master_done_beams)):
        master_done_beams[i] = sorted(master_done_beams[i],
                                      key=lambda x: x['score'],
                                      reverse=True)
    if one_batch:
        return master_done_beams, decode_calls, n_expansions
    else:
        return master_batch_ids, master_done_beams, decode_calls, n_expansions
def main(args, task=None, model_state=None):
    check_args(args)

    use_fp16 = args.fp16
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]), )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (encoder_out["encoder_padding_mask"][i].cpu().numpy()
                                   if encoder_out["encoder_padding_mask"] is not None
                                   else None)
                        features[id.item()] = (feat[i], padding)
                continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (sample["target"][i, :]
                        if "target_label" not in sample
                        else sample["target_label"][i, :])
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"]
                              if "nsentences" in sample else sample["id"].numel())

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} emissions to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                    "sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        gen_timer.sum,
                        num_sentences / gen_timer.sum,
                        1.0 / gen_timer.avg,
                    ))
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
def custom_eval(model, src, trg, beam=5, ap=math.inf, eps=1. / 6, mc=None, method=None):
    model.eval()
    with torch.no_grad():
        tokenized_sentences = [model.encode(sentence) for sentence in src]
        gen_args = copy.copy(model.args)
        gen_args.beam = beam
        gen_args.mc = mc
        generator = build_generator(model.task, model.models, gen_args)
        results = []
        # model.args.max_sentences = 64
        total_loops, total_expansions = 0, 0
        if method == 'variable_stream':
            # TODO adjust other parameters; adjust batching params
            ids, translations, total_loops, total_expansions = generator.variable_beam_stream(
                model,
                tokenized_sentences,
                bos_token=model.task.target_dictionary.eos(),
                ap=ap,
                mc=mc,
                eps=eps)
            for id, hypos in zip(ids, translations):
                results.append((id, hypos))
        else:
            for batch in model._build_batches(tokenized_sentences, False):
                # print('b')
                batch = utils.apply_to_sample(lambda t: t.to(model.device), batch)
                if method is None:
                    translations, n_loops, n_expansions = generator.generate(
                        model.models,
                        batch,
                        bos_token=model.task.target_dictionary.eos(),
                        ap=ap)
                elif method == 'greedy':
                    translations, n_loops, n_expansions = generator.greedy(
                        model.models,
                        batch,
                        bos_token=model.task.target_dictionary.eos())
                elif method == 'variable_beam':
                    translations, n_loops, n_expansions = generator.variable_beam(
                        model,
                        batch,
                        bos_token=model.task.target_dictionary.eos(),
                        ap=ap,
                        mc=mc)
                total_loops += n_loops
                total_expansions += n_expansions
                for id, hypos in zip(batch["id"].tolist(), translations):
                    results.append((id, hypos))
        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        predictions = [model.decode(hypos[0]['tokens']) for hypos in outputs]
        bleu = sacrebleu.corpus_bleu(predictions, [trg]).score
        # print(predictions)
        print('loops', total_loops)
        print('expansions', total_expansions)
        print(bleu)
        return bleu
def move_to_device(sample, device):
    def _move_to_device(tensor):
        return tensor.to(device=device)

    return utils.apply_to_sample(_move_to_device, sample)
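# Example use of move_to_device on a typical (hypothetical) fairseq-style sample
# dict; apply_to_sample leaves non-tensor entries such as "nsentences" untouched:
#
#   sample = {
#       "id": torch.tensor([0, 1]),
#       "nsentences": 2,
#       "net_input": {"src_tokens": torch.randint(0, 100, (2, 7))},
#   }
#   sample = move_to_device(sample, torch.device("cuda:0"))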
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        try:
            override_args = override_args['override_args']
        except TypeError:
            override_args = override_args
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    if use_fp16:
        criterion.half()
    if use_cuda:
        criterion.cuda()
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            sample = utils.apply_to_sample(
                lambda t: t.half() if t.dtype is torch.float32 else t,
                sample) if use_fp16 else sample
            try:
                with torch.no_grad():  # do not save backward passes
                    max_num_rays = 900 * 900
                    if sample['uv'].shape[3] > max_num_rays:
                        sample['ray_split'] = sample['uv'].shape[3] // max_num_rays
                    _loss, _sample_size, log_output = task.valid_step(
                        sample, model, criterion)
                    progress.log(log_output, step=i)
                    log_outputs.append(log_output)
            except TypeError:
                break

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        # summarize all the gpus
        if args.distributed_world_size > 1:
            all_log_output = list(
                zip(*distributed_utils.all_gather_list([log_output])))[0]
            log_output = {
                key: np.mean([log[key] for log in all_log_output])
                for key in all_log_output[0]
            }

        progress.print(log_output, tag=subset, step=i)