Example #1
 def test_t5_tokenizer(self):
     model_name = 't5-small'
     config = util.load_config(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     token_ids = tokenizer(' tokenization')['input_ids']
     is_partial_1 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[0]))
     is_partial_2 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[1]))
     assert not is_partial_1
     assert is_partial_2
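To see why the test expects the first piece to start a word and the second to continue one, it helps to print what the T5 tokenizer actually produces. The pieces depend on the SentencePiece vocabulary, so the output shown in the comment is indicative rather than guaranteed:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('t5-small')
print(tok.convert_ids_to_tokens(tok(' tokenization')['input_ids']))
# Typically something like ['▁token', 'ization', '</s>']: the '▁' marks a
# word start, so 'ization' is the partial token the assertions check for.
```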
Example #2
 def test_gpt_tokenizer(self):
     model_name = 'distilgpt2'
     config = util.load_config(model_name)
     # distilgpt2 shares GPT-2's byte-level BPE vocabulary, so the gpt2 tokenizer applies.
     tokenizer = AutoTokenizer.from_pretrained('gpt2')
     token_ids = tokenizer(' tokenization')['input_ids']
     is_partial_1 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[0]))
     is_partial_2 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[1]))
     assert not is_partial_1
     assert is_partial_2
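The GPT-2 family uses a different word-start marker: its byte-level BPE prefixes word-initial pieces with 'Ġ' (an encoded leading space) rather than '▁'. A quick check, again with indicative output:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')  # distilgpt2 shares this vocabulary
print(tok.convert_ids_to_tokens(tok(' tokenization')['input_ids']))
# Typically ['Ġtoken', 'ization']: 'Ġtoken' starts the word, 'ization' is partial.
```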
Example #3
 def test_bert_tokenizer(self):
     model_name = 'bert-base-uncased'
     config = util.load_config(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     token_ids = tokenizer(' tokenization')['input_ids']
     is_partial_1 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[1]))  # token_ids[0] is [CLS]
     is_partial_2 = util.is_partial_token(
         config, tokenizer.convert_ids_to_tokens(token_ids[2]))
     assert not is_partial_1
     assert is_partial_2
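The three tests above cover the three marker conventions `util.is_partial_token` has to distinguish: SentencePiece '▁' (T5), byte-level BPE 'Ġ' (GPT-2), and WordPiece '##' (BERT). A rough, hypothetical re-implementation of the check is sketched below; the `'tokenizer_config'` key appears in the display code later on this page, but the inner key names (`token_prefix`, `partial_token_prefix`) are assumptions about the config layout, not ecco's confirmed schema:

```python
def is_partial_token_sketch(model_config: dict, token: str) -> bool:
    """Approximation of util.is_partial_token: decide whether a vocabulary
    piece continues the previous word rather than starting a new one."""
    tok_cfg = model_config['tokenizer_config']
    word_start = tok_cfg.get('token_prefix', '')       # e.g. '▁' (T5), 'Ġ' (GPT-2), '' (BERT)
    partial = tok_cfg.get('partial_token_prefix', '')  # e.g. '##' (BERT WordPiece)
    if partial:
        # WordPiece-style vocabularies mark continuations explicitly.
        return token.startswith(partial)
    # SentencePiece / byte-level BPE mark word starts instead, so any piece
    # without the marker continues the previous word.
    return not token.startswith(word_start)
```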
Example #4
    def display_token(self, viz_id, token_id, position):

        # The raw piece keeps the vocabulary marker (needed for the partial-token check);
        # the decoded form is what gets displayed.
        raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
        clean_token = self.tokenizer.decode(token_id)
        # Strip prefixes because BERT's decode() still leaves ## on partial tokens.
        clean_token = strip_tokenizer_prefix(self.model_config, clean_token)

        token = {
            'token': clean_token,
            'is_partial': is_partial_token(self.model_config, raw_token),
            'token_id': int(token_id),
            'position': position,
            'type': 'output'
        }
        js = f"""
        // The require call isn't strictly needed here, but it keeps this code
        // from running before display_input_sequence, which DOES load external files.
        requirejs(['basic', 'ecco'], function(basic, ecco){{
                console.log('addToken viz_id', '{viz_id}');
                window.ecco['{viz_id}'].addToken({json.dumps(token)})
                window.ecco['{viz_id}'].redraw()
        }})
        """
        # print(js)
        d.display(d.Javascript(js))
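The `strip_tokenizer_prefix` call exists because, as the comment notes, BERT's `decode()` still leaves the `##` marker on partial tokens. A hypothetical version of that helper, under the same config-key assumptions as the sketch above:

```python
def strip_tokenizer_prefix_sketch(model_config: dict, token: str) -> str:
    """Approximation of strip_tokenizer_prefix: drop whatever word-start or
    continuation marker the vocabulary uses so only display text remains."""
    tok_cfg = model_config['tokenizer_config']
    for prefix in (tok_cfg.get('token_prefix', ''), tok_cfg.get('partial_token_prefix', '')):
        if prefix and token.startswith(prefix):
            token = token[len(prefix):]
    return token
```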
Example #5
    def display_input_sequence(self, input_ids):

        tokens = []
        for idx, token_id in enumerate(input_ids):
            type = "input"
            raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
            clean_token = self.tokenizer.decode(token_id)
            # Strip prefixes because bert decode still has ## for partials even after decode()
            clean_token = strip_tokenizer_prefix(self.model_config, clean_token)
            tokens.append({
                # 'token': self.tokenizer.decode([token_id]),
                           'token': clean_token,
                           'is_partial': is_partial_token(self.model_config, raw_token),
                           'position': idx,
                           'token_id': int(token_id),
                           'type': type})
        data = {'tokens': tokens}

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        viz_id = f'viz_{round(random.random() * 1000000)}'

        # TODO: Stop passing tokenization_config to JS now that
        # it's handled with the is_partial parameter
        js = f"""
         requirejs( ['basic', 'ecco'], function(basic, ecco){{
            basic.init('{viz_id}') // Python needs to know the viz id. Used for each output token.
            window.ecco['{viz_id}'] = new ecco.renderOutputSequence({{
                    parentDiv: '{viz_id}',
                    data: {json.dumps(data)},
                    tokenization_config: {json.dumps(self.model_config['tokenizer_config'])}
            }})
         }}, function (err) {{
            console.log(err);
        }})
        """

        d.display(d.Javascript(js))
        return viz_id
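`display_input_sequence` builds a list of token dictionaries, renders them once, and returns a `viz_id` that `display_token` later uses to stream each generated token into the same view. The payload itself can be reproduced without any ecco internals; the sketch below builds the same structure from a plain tokenizer, hard-coding the GPT-2 'Ġ' convention that the real code reads from the model config:

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
input_ids = tokenizer(' tokenization works')['input_ids']

tokens = []
for idx, token_id in enumerate(input_ids):
    raw_token = tokenizer.convert_ids_to_tokens([token_id])[0]
    tokens.append({
        'token': tokenizer.decode([token_id]).strip(),  # display text; prefix stripping approximated with strip()
        'is_partial': not raw_token.startswith('Ġ'),    # GPT-2 convention only; ecco reads this from the config
        'position': idx,
        'token_id': int(token_id),
        'type': 'input',
    })

print(json.dumps({'tokens': tokens}, indent=2))  # same shape as the `data` dict passed to the JS side
```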
Example #6
    def primary_attributions(self,
                             attr_method: Optional[str] = 'grad_x_input',
                             style="minimal",
                             ignore_tokens: Optional[List[int]] = [],
                             **kwargs):
        """
            Explorable showing primary attributions of each token generation step.
            Hovering over or tapping an output token overlays a saliency map on the other tokens,
            showing their importance as features for that prediction.

            Examples:

            ```python
            import ecco
            lm = ecco.from_pretrained('distilgpt2')
            text= "The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4."
            output = lm.generate(text, generate=20, do_sample=True)

            # Show primary attributions explorable
            output.primary_attributions()
            ```

            Which creates the following interactive explorable:
            ![input saliency example 1](../../img/saliency_ex_1.png)

            If we want more details on the saliency values, we can use the detailed view:

            ```python
            # Show detailed explorable
            output.primary_attributions(style="detailed")
            ```

            Which creates the following interactive explorable:

            ![input saliency example 2 - detailed](../../img/saliency_ex_2.png)


            Details:
            This view shows the Gradient * Inputs method of input saliency. The attribution values are
            calculated across the embedding dimensions; the L2 norm of each token's embedding-dimension
            values is then taken to produce a single score per token. To express the scores as percentages,
            they are normalized by dividing by the sum of the attribution scores of all tokens in the sequence.
        """
        position = self.n_input_tokens

        importance_id = position - self.n_input_tokens
        tokens = []

        assert attr_method in self.attribution, \
            f"attr_method={attr_method} not found. Choose one of the following: {list(self.attribution.keys())}"
        attribution = self.attribution[attr_method]
        for idx, token in enumerate(self.tokens[0]):
            token_id = self.token_ids[0][idx]
            raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
            clean_token = self.tokenizer.decode(token_id)
            # Strip prefixes because BERT's decode() still leaves ## on partial tokens.
            clean_token = strip_tokenizer_prefix(self.config, clean_token)

            token_type = "input" if idx < self.n_input_tokens else 'output'
            if idx < len(attribution[importance_id]):
                imp = attribution[importance_id][idx]
            else:
                imp = 0

            tokens.append({
                'token': clean_token,
                'token_id': int(self.token_ids[0][idx]),
                'is_partial': is_partial_token(self.config, raw_token),
                'type': token_type,
                'value': str(imp),  # stringified because JSON serialization complains about floats; probably unused
                'position': idx
            })

        if len(ignore_tokens) > 0:
            for output_token_index, _ in enumerate(attribution):
                for idx in ignore_tokens:
                    attribution[output_token_index][idx] = 0

        data = {
            'tokens': tokens,
            'attributions': [att.tolist() for att in attribution]
        }

        d.display(
            d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        if (style == "minimal"):
            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                // ecco.interactiveTokens(viz_id, {{}})
                window.ecco[viz_id] = new ecco.MinimalHighlighter({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    preset: 'viridis',
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}

             }})

             window.ecco[viz_id].init();
             window.ecco[viz_id].selectFirstToken();

             }}, function (err) {{
                console.log(err);
            }})"""
        elif (style == "detailed"):

            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                window.ecco[viz_id] = ecco.interactiveTokens({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}
             }})

             }}, function (err) {{
                console.log(err);
            }})"""

        d.display(d.Javascript(js))

        if 'printJson' in kwargs and kwargs['printJson']:
            print(data)
            return data
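The docstring's Details section describes the computation, but the attribution code itself lives elsewhere in ecco. The following standalone sketch reproduces that recipe (gradient times input, L2 norm over the embedding dimensions, then normalization to a share of the total) with plain PyTorch and transformers; it is an approximation of the idea, not ecco's actual implementation:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('distilgpt2')
model.eval()
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')

ids = tokenizer("The countries of the European Union are:\n1.", return_tensors='pt')['input_ids']
embeddings = model.get_input_embeddings()(ids)   # (1, seq_len, hidden)
embeddings.retain_grad()                         # keep the gradient of this non-leaf tensor

logits = model(inputs_embeds=embeddings).logits  # (1, seq_len, vocab)
target_logit = logits[0, -1].max()               # score of the most likely next token
target_logit.backward()

grad_x_input = embeddings.grad[0] * embeddings[0]  # Gradient * Inputs, per embedding dimension
scores = grad_x_input.norm(p=2, dim=-1)            # L2 norm over embedding dims -> one score per token
scores = scores / scores.sum()                     # normalize so the scores sum to 1
for tok, score in zip(tokenizer.convert_ids_to_tokens(ids[0].tolist()), scores.tolist()):
    print(f"{tok!r:>12}  {score:.3f}")
```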