Example #1
0
    def make_activations(self, fields, lengths) -> Activations:
        """Create Activations from the output tuples produced by PyTorch Transformers.
        Includes converting torch tensors to xp, and handling missing values.

        fields: the raw output tuple from the transformer model. The first
            element is the last-hidden-state tensor; when four elements are
            present they are (last_hidden, pooler_output, all_hidden,
            all_attention).
        lengths: per-sequence lengths used to strip padding.
        """
        fields = list(fields)
        # Convert the last-hidden-state tensor to xp and re-rag it by lengths.
        fields[0] = torch2xp(fields[0])
        fields[0] = RaggedArray.from_padded(fields[0], lengths)
        assert fields[0].data.shape[0] == sum(lengths)
        # lh: last hidden
        # po: pooler_output
        # ah: all_hidden
        # aa: all_attention
        if len(fields) != 4:
            lh = fields[0]
            po = RaggedArray.blank()
            # Bug fix: ah/aa were previously left unassigned on this branch,
            # raising NameError at the return below whenever the model
            # produced fewer than four output fields.
            ah = []
            aa = []
        else:
            if isinstance(fields[1], tuple):
                # A tuple in the pooler slot marks a missing pooler output.
                fields[1] = RaggedArray.blank()
            else:
                fields[1] = RaggedArray(torch2xp(fields[1]), [1] * len(lengths))

            fields[2] = [
                RaggedArray.from_padded(torch2xp(fields[2][i]), lengths)
                for i in range(len(fields[2]))
            ]

            lh, po, ah, aa = fields

        return Activations(lh, po, ah, aa)
Example #2
0
 def from_pytt(cls, fields, *, is_grad=False) -> "Activations":
     """Create Activations from the output tuples produced by PyTorch Transformers.
     Includes converting torch tensors to xp, and handling missing values.
     """
     # Field order: last hidden, pooler_output, all_hidden, all_attention.
     # Anything other than a 4-tuple means only the last hidden state is
     # available; the remaining fields get empty placeholders.
     if len(fields) == 4:
         lh, po, ah, aa = fields
     else:
         lh, po, ah, aa = fields[0], tuple(), [], []
     # Convert last_hidden_state to xp
     lh = torch2xp(lh)
     xp = get_array_module(lh)
     # A tuple in the pooler slot is the "None" marker; normalize it to an
     # empty array of the matching dtype.
     if isinstance(po, tuple):
         po = xp.zeros((0, ), dtype=lh.dtype)
     else:
         po = torch2xp(po)
     ah = [torch2xp(tensor) for tensor in ah]
     aa = [torch2xp(tensor) for tensor in aa]
     return cls(lh, po, ah, aa, is_grad=is_grad)
 def _update_pytorch_averages(self, sgd, *, init_steps=1):
     """Fold the wrapped PyTorch model's parameters into the optimizer's
     running-averages table, seeding entries for parameters seen for the
     first time.
     """
     # Nothing to do when the optimizer keeps no averages.
     if sgd.averages is None:
         return
     for name, param in self._model.state_dict().items():
         key = f"pytorch_{self.id}_{name}"
         sgd.nr_update[key] += 1
         xp_param = torch2xp(param)
         if key not in sgd.averages:
             # First sighting: seed the average with a copy of the parameter
             # and reset the update counter to init_steps.
             sgd.averages[key] = xp_param.copy()
             sgd.nr_update[key] = init_steps
         else:
             self.ops.update_averages(
                 sgd.averages[key], xp_param, sgd.nr_update[key]
             )
 def from_pytt(cls, fields, *, is_grad=False) -> "Activations":
     """Create Activations from the output tuples produced by PyTorch Transformers.
     Includes converting torch tensors to xp, and handling missing values.

     fields: output tuple in the order (last_hidden, pooler_output,
         all_hidden, all_attention); trailing fields may be missing or None.
     is_grad: passed through to the constructed Activations.
     """
     fields = list(fields)
     # Make sure we have 4 elements
     while len(fields) < 4:
         fields.append([])
     # Normalize None to [].  Bug fix: the previous expression
     # `f if f is not None else f` was a no-op and left None values in
     # place, crashing the conversions below.
     fields = [f if f is not None else [] for f in fields]
     # lh: last hidden
     # po: pooler_output
     # ah: all_hidden
     # aa: all_attention
     lh, po, ah, aa = fields
     # Convert last_hidden_state to xp
     lh = torch2xp(lh)
     # Normalize "None" value for pooler output: an all-None tuple marks a
     # missing pooler output.
     if isinstance(po, tuple) and all(x is None for x in po):
         po = []
     po = list(map(torch2xp, po))
     ah = list(map(torch2xp, ah))
     aa = list(map(torch2xp, aa))
     return cls(lh, po, ah, aa, is_grad=is_grad)