Code example #1
    def display_token(self, viz_id, token_id, position):

        raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
        clean_token = self.tokenizer.decode(token_id)
        # Strip the prefix because BERT's decode() still leaves ## on partial tokens
        clean_token = strip_tokenizer_prefix(self.model_config, clean_token)

        token = {
            'token': clean_token,
            'is_partial': is_partial_token(self.model_config, raw_token),
            'token_id': int(token_id),
            'position': position,
            'type': 'output'
        }
        js = f"""
        // We don't strictly need these require scripts, but they prevent
        // this code from running before display_input_sequence, which DOES require external files
        requirejs(['basic', 'ecco'], function(basic, ecco){{
                console.log('addToken viz_id', '{viz_id}');
                window.ecco['{viz_id}'].addToken({json.dumps(token)})
                window.ecco['{viz_id}'].redraw()
        }})
        """
        d.display(d.Javascript(js))
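
The snippet above leans on two helpers, `strip_tokenizer_prefix` and `is_partial_token`. A minimal sketch of what they might look like, assuming `model_config['tokenizer_config']` carries `token_prefix` and `partial_token_prefix` strings (the key names and logic are illustrative guesses based on how the helpers are used here, not the library's actual implementation):

```python
# A hedged sketch, NOT the library's actual implementation.
# Assumes model_config['tokenizer_config'] holds two (assumed) keys:
#   'token_prefix'         - marker on word-initial tokens (e.g. 'Ġ' in GPT-2's BPE)
#   'partial_token_prefix' - marker on word continuations (e.g. '##' in BERT WordPiece)

def strip_tokenizer_prefix(model_config: dict, token: str) -> str:
    """Drop display prefixes so tokens render cleanly in the visualization."""
    cfg = model_config['tokenizer_config']
    for prefix in (cfg.get('token_prefix', ''), cfg.get('partial_token_prefix', '')):
        if prefix and token.startswith(prefix):
            return token[len(prefix):]
    return token


def is_partial_token(model_config: dict, raw_token: str) -> bool:
    """True when the raw token continues the previous word (BERT-style '##').

    Note: GPT-2-style BPE inverts the convention -- a token is partial when it
    LACKS the leading space marker -- so a real implementation would need to
    branch on the tokenizer family.
    """
    prefix = model_config['tokenizer_config'].get('partial_token_prefix', '')
    return bool(prefix) and raw_token.startswith(prefix)
```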
Code example #2
    def display_input_sequence(self, input_ids):

        tokens = []
        for idx, token_id in enumerate(input_ids):
            type = "input"
            raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
            clean_token = self.tokenizer.decode(token_id)
            # Strip the prefix because BERT's decode() still leaves ## on partial tokens
            clean_token = strip_tokenizer_prefix(self.model_config, clean_token)
            tokens.append({
                'token': clean_token,
                'is_partial': is_partial_token(self.model_config, raw_token),
                'position': idx,
                'token_id': int(token_id),
                'type': type
            })
        data = {'tokens': tokens}

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        viz_id = f'viz_{round(random.random() * 1000000)}'

        # TODO: Stop passing tokenization_config to JS now that
        # it's handled with the is_partial parameter
        js = f"""
         requirejs( ['basic', 'ecco'], function(basic, ecco){{
            basic.init('{viz_id}') // Python needs to know the viz id. Used for each output token.
            window.ecco['{viz_id}'] = new ecco.renderOutputSequence({{
                    parentDiv: '{viz_id}',
                    data: {json.dumps(data)},
                    tokenization_config: {json.dumps(self.model_config['tokenizer_config'])}
            }})
         }}, function (err) {{
            console.log(err);
        }})
        """

        d.display(d.Javascript(js))
        return viz_id
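
Taken together, the two methods form a small protocol: `display_input_sequence` renders the prompt once and returns a `viz_id`, and `display_token` is then called once per generated token to append it to the same visualization. A hypothetical driver loop (`generate_one_token` is an assumed stand-in for one decoding step, not part of the code above):

```python
# Hypothetical usage of the display protocol sketched above.
viz_id = self.display_input_sequence(input_ids)  # render the prompt, get the viz handle

position = len(input_ids)
for _ in range(n_new_tokens):
    token_id = generate_one_token()              # assumed: one decoding step
    self.display_token(viz_id, token_id, position)
    position += 1
```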
Code example #3
    def rankings_watch(self,
                       watch: List[int] = None,
                       position: int = -1,
                       **kwargs):
        """
        Plots the rankings of the tokens whose ids are supplied in the watch list.
        Only considers one position.

        ![Rankings watch plot](../../img/ranking_watch_ex_is_are_1.png)
        """

        assert self.model_type != 'mlm', "method not supported for Masked-LMs"
        assert watch is not None and len(watch) > 0, "supply at least one token id in `watch`"

        _, dec_hidden_states = self._get_hidden_states()
        assert dec_hidden_states is not None, "decoder hidden states not found"

        if position != -1:
            if self.model_type in ['enc-dec', 'causal']:
                # The position is relative: position self.n_input_tokens + 1 is the first generated token
                offset = 1 if self.model_type == 'enc-dec' else 0
                new_position = position - offset - self.n_input_tokens
                assert new_position >= 0, f"position={position} not supported, minimum is " \
                                          f"position={self.n_input_tokens + offset} for the first generated token"
                assert new_position < len(dec_hidden_states), f"position={position} not supported, maximum is " \
                                                              f"position={len(dec_hidden_states) - 1 + self.n_input_tokens + offset} " \
                                                              f"for the last generated token."
                position = new_position
            else:
                raise NotImplementedError(
                    f"model_type={self.model_type} not supported")

        dec_hidden_states = dec_hidden_states[position][:, -1, :]

        n_layers_dec = len(dec_hidden_states)
        n_tokens_to_watch = len(watch)
        rankings = np.zeros((n_layers_dec, n_tokens_to_watch), dtype=np.int32)

        # loop through layer levels
        for i, level in enumerate(dec_hidden_states):
            # Loop through the watched token ids
            for j, token_id in enumerate(watch):

                # Project hidden state to vocabulary
                # (hard-won debugging lesson: make sure the input is on the GPU, if appropriate)
                logits = self.lm_head(self.to(level))

                # Sort token ids by score (ascending)
                sorted_ids = torch.argsort(logits)

                # The watched token id, as a tensor for comparison
                token_id = torch.tensor(token_id)

                # Where does the watched token land in the sorted list?
                r = torch.nonzero((sorted_ids == token_id)).flatten()

                # Subtract to get the ranking (1 is the top score, since the sort was ascending)
                ranking = sorted_ids.shape[0] - r
                rankings[i, j] = int(ranking)

        input_tokens = [
            strip_tokenizer_prefix(self.config, t) for t in self.tokens[0]
        ]
        output_tokens = [repr(self.tokenizer.decode(t)) for t in watch]

        lm_plots.plot_inner_token_rankings_watch(
            input_tokens, output_tokens, rankings,
            position + self.n_input_tokens if self.model_type == 'enc-dec' else position)

        if kwargs.get('printJson'):
            data = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'rankings': rankings
            }
            print(data)
            return data
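
The heart of the loop is the argsort rank trick: `torch.argsort` sorts ascending, so the highest-scoring token id lands at the last index, and a token's rank is the vocabulary size minus its index in the sorted order. A tiny standalone example with a four-token vocabulary:

```python
import torch

logits = torch.tensor([0.1, 2.5, -0.3, 1.2])         # toy vocabulary of 4 tokens
sorted_ids = torch.argsort(logits)                   # ascending: tensor([2, 0, 3, 1])

token_id = torch.tensor(1)                           # the token we watch
r = torch.nonzero(sorted_ids == token_id).flatten()  # index 3 (last = highest score)
ranking = sorted_ids.shape[0] - r                    # 4 - 3 = 1, i.e. rank 1 (top)
print(int(ranking))                                  # -> 1
```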
Code example #4
    def rankings(self, **kwargs):
        """
        Plots the rankings (across layers) of the tokens the model selected.
        Each column is a position in the sequence. Each row is a layer.

        ![Rankings plot](../../img/rankings_ex_eu_1.png)
        """

        assert self.model_type != 'mlm', "method not supported for Masked-LMs"

        _, dec_hidden_states = self._get_hidden_states()
        assert dec_hidden_states is not None, "decoder hidden states not found"

        n_layers_dec = dec_hidden_states[0].shape[0]
        n_positions = len(dec_hidden_states)

        rankings = np.zeros((n_layers_dec, n_positions), dtype=np.int32)
        predicted_tokens = np.empty((n_layers_dec, n_positions), dtype='U25')
        token_found_mask = np.ones((n_layers_dec, n_positions))

        # Loop through the hidden states of each generated/output position
        for j, token_hidden_states in enumerate(dec_hidden_states):

            # Loop through the layers at this position
            for i, hidden_state in enumerate(token_hidden_states[:, -1, :]):

                # Project hidden state to vocabulary
                # (hard-won debugging lesson: make sure the input is on the GPU, if appropriate)
                logits = self.lm_head(self.to(hidden_state))

                # Sort token ids by score (ascending)
                sorted_ids = torch.argsort(logits)

                # Which token was actually sampled at this position?
                offset = self.n_input_tokens + 1 if self.model_type == 'enc-dec' else self.n_input_tokens
                token_id = torch.tensor(self.token_ids[0][offset + j])

                # What's the index of the sampled token in the sorted list?
                r = torch.nonzero((sorted_ids == token_id)).flatten()

                # Subtract to get the ranking (1 is the top score, since the sort was ascending)
                ranking = sorted_ids.shape[0] - r
                token = self.tokenizer.decode([token_id])
                predicted_tokens[i, j] = token
                rankings[i, j] = int(ranking)
                if token_id == self.token_ids[0][j + 1]:
                    token_found_mask[i, j] = 0

        input_tokens = [
            repr(strip_tokenizer_prefix(self.config, t))
            for t in self.tokens[0][self.n_input_tokens - 1:-1]
        ]
        offset = self.n_input_tokens + 1 if self.model_type == 'enc-dec' else self.n_input_tokens
        output_tokens = [
            repr(strip_tokenizer_prefix(self.config, t))
            for t in self.tokens[0][offset:]
        ]

        lm_plots.plot_inner_token_rankings(input_tokens, output_tokens,
                                           rankings, **kwargs)

        if kwargs.get('printJson'):
            data = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'rankings': rankings,
                'predicted_tokens': predicted_tokens,
                'token_found_mask': token_found_mask
            }
            print(data)
            return data
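
For reference, a usage sketch for `rankings()`, mirroring the generation example from the `primary_attributions` docstring below (the model name and prompt are illustrative):

```python
import ecco

lm = ecco.from_pretrained('distilgpt2')
text = "The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4."
output = lm.generate(text, generate=20, do_sample=True)

# Plot the layer-by-layer ranking of every generated token
output.rankings()

# Or also get the underlying arrays back
data = output.rankings(printJson=True)
```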
Code example #5
    def primary_attributions(self,
                             attr_method: Optional[str] = 'grad_x_input',
                             style="minimal",
                             ignore_tokens: Optional[List[int]] = None,
                             **kwargs):
        """
            Explorable showing primary attributions of each token generation step.
            Hovering over or tapping an output token overlays a saliency map on the other tokens,
            showing their importance as features for that prediction.

            Examples:

            ```python
            import ecco
            lm = ecco.from_pretrained('distilgpt2')
            text= "The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4."
            output = lm.generate(text, generate=20, do_sample=True)

            # Show primary attributions explorable
            output.primary_attributions()
            ```

            Which creates the following interactive explorable:
            ![input saliency example 1](../../img/saliency_ex_1.png)

            If we want more details on the saliency values, we can use the detailed view:

            ```python
            # Show detailed explorable
            output.primary_attributions(style="detailed")
            ```

            Which creates the following interactive explorable:

            ![input saliency example 2 - detailed](../../img/saliency_ex_2.png)


            Details:
            This view shows the Gradient * Inputs method of input saliency. The attribution values are calculated
            across the embedding dimensions, then the L2 norm of each token's embedding-dimension values is taken
            to produce one score per token. To get a percentage, the scores are normalized by dividing by the sum
            of the attribution scores of all the tokens in the sequence.
        """
        position = self.n_input_tokens

        importance_id = position - self.n_input_tokens
        tokens = []

        assert attr_method in self.attribution, \
            f"attr_method={attr_method} not found. Choose one of the following: {list(self.attribution.keys())}"
        attribution = self.attribution[attr_method]
        for idx, token in enumerate(self.tokens[0]):
            token_id = self.token_ids[0][idx]
            raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
            clean_token = self.tokenizer.decode(token_id)
            # Strip the prefix because BERT's decode() still leaves ## on partial tokens
            clean_token = strip_tokenizer_prefix(self.config, clean_token)

            type = "input" if idx < self.n_input_tokens else 'output'
            if idx < len(attribution[importance_id]):
                imp = attribution[importance_id][idx]
            else:
                imp = 0

            tokens.append({
                'token': clean_token,
                'token_id': int(self.token_ids[0][idx]),
                'is_partial': is_partial_token(self.config, raw_token),
                'type': type,
                # Stringified because JSON serialization complains about floats. Probably unused.
                'value': str(imp),
                'position': idx
            })

        if ignore_tokens:
            for output_token_index, _ in enumerate(attribution):
                for idx in ignore_tokens:
                    attribution[output_token_index][idx] = 0

        data = {
            'tokens': tokens,
            'attributions': [att.tolist() for att in attribution]
        }

        d.display(
            d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        if style == "minimal":
            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                window.ecco[viz_id] = new ecco.MinimalHighlighter({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    preset: 'viridis',
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}

             }})

             window.ecco[viz_id].init();
             window.ecco[viz_id].selectFirstToken();

             }}, function (err) {{
                console.log(err);
            }})"""
        elif style == "detailed":
            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                window.ecco[viz_id] = ecco.interactiveTokens({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}
             }})

             }}, function (err) {{
                console.log(err);
            }})"""

        d.display(d.Javascript(js))

        if kwargs.get('printJson'):
            print(data)
            return data
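
The scoring described in the docstring's Details paragraph (Gradient * Inputs, an L2 norm across embedding dimensions, then sum normalization) can be reproduced in isolation. A minimal sketch with illustrative shapes; this is not the library's internal attribution code:

```python
import torch

# Assume per-token gradients and input embeddings of shape (seq_len, embedding_dim)
seq_len, emb_dim = 5, 8
grad = torch.randn(seq_len, emb_dim)        # d(output logit) / d(input embedding)
embeddings = torch.randn(seq_len, emb_dim)  # the input embeddings themselves

grad_x_input = grad * embeddings            # Gradient * Inputs, per embedding dimension
scores = torch.norm(grad_x_input, dim=-1)   # L2 norm -> one saliency score per token
saliency = scores / scores.sum()            # normalize so scores sum to 1 (percentages)

print(saliency)        # relative importance of each input token
print(saliency.sum())  # -> tensor(1.)
```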