def display_token(self, viz_id, token_id, position):
    raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
    clean_token = self.tokenizer.decode(token_id)
    # Strip prefixes because BERT's decode still emits ## for partial tokens even after decode()
    clean_token = strip_tokenizer_prefix(self.model_config, clean_token)
    token = {'token': clean_token,
             'is_partial': is_partial_token(self.model_config, raw_token),
             'token_id': int(token_id),
             'position': position,
             'type': 'output'}
    js = f"""
    // We don't strictly need these require scripts, but they keep this code
    // from running before display_input_sequence, which DOES require external files.
    requirejs(['basic', 'ecco'], function(basic, ecco){{
        console.log('addToken viz_id', '{viz_id}');
        window.ecco['{viz_id}'].addToken({json.dumps(token)})
        window.ecco['{viz_id}'].redraw()
    }})
    """
    d.display(d.Javascript(js))
def display_input_sequence(self, input_ids):
    tokens = []
    for idx, token_id in enumerate(input_ids):
        raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
        clean_token = self.tokenizer.decode(token_id)
        # Strip prefixes because BERT's decode still emits ## for partial tokens even after decode()
        clean_token = strip_tokenizer_prefix(self.model_config, clean_token)
        tokens.append({'token': clean_token,
                       'is_partial': is_partial_token(self.model_config, raw_token),
                       'position': idx,
                       'token_id': int(token_id),
                       'type': 'input'})
    data = {'tokens': tokens}

    d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))
    viz_id = f'viz_{round(random.random() * 1000000)}'

    # TODO: Stop passing tokenization_config to JS now that
    # it's handled with the is_partial parameter
    js = f"""
    requirejs(['basic', 'ecco'], function(basic, ecco){{
        basic.init('{viz_id}')  // Python needs to know the viz id. Used for each output token.
        window.ecco['{viz_id}'] = new ecco.renderOutputSequence({{
            parentDiv: '{viz_id}',
            data: {json.dumps(data)},
            tokenization_config: {json.dumps(self.model_config['tokenizer_config'])}
        }})
    }}, function (err) {{
        console.log(err);
    }})
    """
    d.display(d.Javascript(js))
    return viz_id
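# A minimal usage sketch for the two display helpers above, assuming the
# streaming-generation flow this class implies (`input_ids`, `generated_ids`,
# and the surrounding loop are illustrative assumptions, not part of this file):
#
#     viz_id = self.display_input_sequence(input_ids)       # render the prompt once
#     for position, token_id in enumerate(generated_ids,
#                                         start=len(input_ids)):
#         self.display_token(viz_id, token_id, position)    # stream each new token
#
# display_input_sequence returns the viz_id that every later addToken call
# uses to target the same rendered explorable.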
def rankings_watch(self, watch: List[int] = None, position: int = -1, **kwargs):
    """
    Plots the rankings of the tokens whose ids are supplied in the watch list.
    Only considers one position.

    ![Rankings watch plot](../../img/ranking_watch_ex_is_are_1.png)
    """
    assert self.model_type != 'mlm', "method not supported for Masked-LMs"
    assert watch is not None and len(watch) > 0, "supply at least one token id to watch"

    _, dec_hidden_states = self._get_hidden_states()
    assert dec_hidden_states is not None, "decoder hidden states not found"

    if position != -1:
        if self.model_type in ['enc-dec', 'causal']:
            # The position is relative: position self.n_input_tokens + offset
            # is the first generated token.
            offset = 1 if self.model_type == 'enc-dec' else 0
            new_position = position - offset - self.n_input_tokens
            assert new_position >= 0, \
                f"position={position} not supported, minimum is " \
                f"position={self.n_input_tokens + offset} for the first generated token"
            assert new_position < len(dec_hidden_states), \
                f"position={position} not supported, maximum is " \
                f"position={len(dec_hidden_states) - 1 + self.n_input_tokens + offset} " \
                f"for the last generated token"
            position = new_position
        else:
            raise NotImplementedError(f"model_type={self.model_type} not supported")

    dec_hidden_states = dec_hidden_states[position][:, -1, :]

    n_layers_dec = len(dec_hidden_states) if dec_hidden_states is not None else 0
    n_tokens_to_watch = len(watch)
    rankings = np.zeros((n_layers_dec, n_tokens_to_watch), dtype=np.int32)

    # Loop through the layers
    for i, level in enumerate(dec_hidden_states):
        # Project the layer's hidden state onto the vocabulary
        # (after debugging pain: ensure the input is on the GPU, if appropriate)
        logits = self.lm_head(self.to(level))
        # Sort token ids by score, ascending. This depends only on the layer,
        # so it is hoisted out of the inner loop over the watched tokens.
        sorted_ids = torch.argsort(logits)
        # Loop through the watched token ids
        for j, token_id in enumerate(watch):
            # Find the index of the watched token in the sorted list...
            r = torch.nonzero(sorted_ids == torch.tensor(token_id)).flatten()
            # ...and subtract it from the vocabulary size to get the ranking
            # (rank 1 is the top score, because the sort is ascending)
            ranking = sorted_ids.shape[0] - r
            rankings[i, j] = int(ranking)

    input_tokens = [strip_tokenizer_prefix(self.config, t) for t in self.tokens[0]]
    output_tokens = [repr(self.tokenizer.decode(t)) for t in watch]

    lm_plots.plot_inner_token_rankings_watch(
        input_tokens,
        output_tokens,
        rankings,
        position + self.n_input_tokens if self.model_type == 'enc-dec' else position)

    if 'printJson' in kwargs and kwargs['printJson']:
        data = {'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'rankings': rankings}
        print(data)
        return data
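# A minimal usage sketch, assuming `output` is the object returned by an
# ecco-style `lm.generate(...)` call, and that 318 and 389 happen to be this
# tokenizer's ids for " is" and " are" (the variable name and the ids are
# illustrative assumptions, not taken from this file):
#
#     output.rankings_watch(watch=[318, 389], position=6)
#
# This plots, for sequence position 6 alone, the rank each decoder layer
# assigns to the two watched tokens.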
def rankings(self, **kwargs):
    """
    Plots the rankings (across layers) of the tokens the model selected.
    Each column is a position in the sequence. Each row is a layer.

    ![Rankings plot](../../img/rankings_ex_eu_1.png)
    """
    assert self.model_type != 'mlm', "method not supported for Masked-LMs"

    _, dec_hidden_states = self._get_hidden_states()
    assert dec_hidden_states is not None, "decoder hidden states not found"

    n_layers_dec = dec_hidden_states[0].shape[0]
    position = len(dec_hidden_states)

    rankings = np.zeros((n_layers_dec, position), dtype=np.int32)
    predicted_tokens = np.empty((n_layers_dec, position), dtype='U25')
    token_found_mask = np.ones((n_layers_dec, position))

    # Loop through the hidden states of each generated token
    for j, token_hidden_states in enumerate(dec_hidden_states):
        # Loop through the layers
        for i, hidden_state in enumerate(token_hidden_states[:, -1, :]):
            # Project the hidden state onto the vocabulary
            # (after debugging pain: ensure the input is on the GPU, if appropriate)
            logits = self.lm_head(self.to(hidden_state))
            # Sort token ids by score (ascending)
            sorted_ids = torch.argsort(logits)
            # Which token was sampled at this position?
            offset = self.n_input_tokens + 1 if self.model_type == 'enc-dec' else self.n_input_tokens
            token_id = torch.tensor(self.token_ids[0][offset + j])
            # Find the index of the sampled token in the sorted list...
            r = torch.nonzero(sorted_ids == token_id).flatten()
            # ...and subtract it from the vocabulary size to get the ranking
            # (rank 1 is the top score, because the sort is ascending)
            ranking = sorted_ids.shape[0] - r
            token = self.tokenizer.decode([token_id])
            predicted_tokens[i, j] = token
            rankings[i, j] = int(ranking)
            if token_id == self.token_ids[0][j + 1]:
                token_found_mask[i, j] = 0

    input_tokens = [repr(strip_tokenizer_prefix(self.config, t))
                    for t in self.tokens[0][self.n_input_tokens - 1:-1]]
    offset = self.n_input_tokens + 1 if self.model_type == 'enc-dec' else self.n_input_tokens
    output_tokens = [repr(strip_tokenizer_prefix(self.config, t))
                     for t in self.tokens[0][offset:]]

    lm_plots.plot_inner_token_rankings(input_tokens,
                                       output_tokens,
                                       rankings,
                                       **kwargs)

    if 'printJson' in kwargs and kwargs['printJson']:
        data = {'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'rankings': rankings,
                'predicted_tokens': predicted_tokens,
                'token_found_mask': token_found_mask}
        print(data)
        return data
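# A minimal usage sketch, assuming `output` comes from an ecco-style
# `lm.generate(...)` call (the variable name is an illustrative assumption):
#
#     data = output.rankings(printJson=True)
#     data['rankings'].shape   # (n_layers, n_generated_positions)
#
# Each cell holds the rank a given layer assigned to the token the model
# eventually emitted at that position (1 = top ranked).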
def primary_attributions(self,
                         attr_method: Optional[str] = 'grad_x_input',
                         style="minimal",
                         ignore_tokens: Optional[List[int]] = None,
                         **kwargs):
    """
    Explorable showing the primary attributions of each token generation step.
    Hovering over or tapping an output token overlays a saliency map on the
    other tokens, showing their importance as features for that prediction.

    Examples:

    ```python
    import ecco
    lm = ecco.from_pretrained('distilgpt2')
    text = "The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4."
    output = lm.generate(text, generate=20, do_sample=True)

    # Show primary attributions explorable
    output.primary_attributions()
    ```

    Which creates the following interactive explorable:

    ![input saliency example 1](../../img/saliency_ex_1.png)

    If we want more details on the saliency values, we can use the detailed view:

    ```python
    # Show detailed explorable
    output.primary_attributions(style="detailed")
    ```

    Which creates the following interactive explorable:

    ![input saliency example 2 - detailed](../../img/saliency_ex_2.png)

    Details:
    This view shows the Gradient * Input method of input saliency. The
    attribution values are calculated across the embedding dimensions, then
    the L2 norm over each token's embedding dimensions gives that token a
    score. To get a percentage value, we normalize each score by the sum of
    the attribution scores of all the tokens in the sequence.
    """
    position = self.n_input_tokens
    importance_id = position - self.n_input_tokens
    ignore_tokens = ignore_tokens if ignore_tokens is not None else []

    tokens = []
    assert attr_method in self.attribution, \
        f"attr_method={attr_method} not found. Choose one of the following: " \
        f"{list(self.attribution.keys())}"
    attribution = self.attribution[attr_method]
    for idx, token in enumerate(self.tokens[0]):
        token_id = self.token_ids[0][idx]
        raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
        clean_token = self.tokenizer.decode(token_id)
        # Strip prefixes because BERT's decode still emits ## for partial tokens even after decode()
        clean_token = strip_tokenizer_prefix(self.config, clean_token)
        token_type = "input" if idx < self.n_input_tokens else 'output'
        if idx < len(attribution[importance_id]):
            imp = attribution[importance_id][idx]
        else:
            imp = 0

        tokens.append({'token': clean_token,
                       'token_id': int(self.token_ids[0][idx]),
                       'is_partial': is_partial_token(self.config, raw_token),
                       'type': token_type,
                       # Stringified because JSON serialization complains about floats. Probably unused.
                       'value': str(imp),
                       'position': idx})

    if len(ignore_tokens) > 0:
        for output_token_index, _ in enumerate(attribution):
            for idx in ignore_tokens:
                attribution[output_token_index][idx] = 0

    data = {'tokens': tokens,
            'attributions': [att.tolist() for att in attribution]}

    d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

    if style == "minimal":
        js = f"""
        requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()
            console.log(viz_id)
            window.ecco[viz_id] = new ecco.MinimalHighlighter({{
                parentDiv: viz_id,
                data: {json.dumps(data)},
                preset: 'viridis',
                tokenization_config: {json.dumps(self.config['tokenizer_config'])}
            }})
            window.ecco[viz_id].init();
            window.ecco[viz_id].selectFirstToken();
        }}, function (err) {{
            console.log(err);
        }})"""
    elif style == "detailed":
        js = f"""
        requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()
            console.log(viz_id)
            window.ecco[viz_id] = ecco.interactiveTokens({{
                parentDiv: viz_id,
                data: {json.dumps(data)},
                tokenization_config: {json.dumps(self.config['tokenizer_config'])}
            }})
        }}, function (err) {{
            console.log(err);
        }})"""
    else:
        raise ValueError(f"style={style} not supported. Choose 'minimal' or 'detailed'.")

    d.display(d.Javascript(js))

    if 'printJson' in kwargs and kwargs['printJson']:
        print(data)
        return data
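# A hedged sketch of the ignore_tokens parameter above, assuming `output`
# comes from an ecco-style `lm.generate(...)` call and that position 0 holds
# a special token (e.g. BOS) whose attribution we want zeroed out before
# rendering (both assumptions are illustrative, not taken from this file):
#
#     output.primary_attributions(style="detailed", ignore_tokens=[0])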