def make_activations(self, fields, lengths) -> Activations:
    """Create Activations from the output tuples produced by PyTorch Transformers.

    Includes converting torch tensors to xp, and handling missing values.

    Slots of ``fields``: lh (last hidden), po (pooler output),
    ah (all hidden states), aa (all attentions).

    Args:
        fields: Model output sequence. Only slot 0 is used unless exactly
            four slots are present.
        lengths: Per-sequence lengths used to unpad the padded outputs.

    Returns:
        Activations(lh, po, ah, aa). When ``fields`` does not have exactly
        four slots, po is ``RaggedArray.blank()`` and ah/aa are empty lists.
    """
    fields = list(fields)
    # Convert the padded last hidden state to xp, then unpad it into a
    # RaggedArray using the per-sequence lengths.
    fields[0] = RaggedArray.from_padded(torch2xp(fields[0]), lengths)
    assert fields[0].data.shape[0] == sum(lengths)
    if len(fields) != 4:
        lh = fields[0]
        po = RaggedArray.blank()
        # No hidden-state / attention outputs available: bind empty lists
        # so the constructor call below never hits unbound names.
        ah = []
        aa = []
    else:
        if isinstance(fields[1], tuple):
            # A tuple in the pooler slot is treated as "no pooler output".
            fields[1] = RaggedArray.blank()
        else:
            # Segment length 1 for each of the len(lengths) sequences.
            fields[1] = RaggedArray(torch2xp(fields[1]), [1] * len(lengths))
        # Unpad every hidden-state layer the same way as the last hidden state.
        fields[2] = [
            RaggedArray.from_padded(torch2xp(layer), lengths) for layer in fields[2]
        ]
        # NOTE(review): fields[3] (all attentions) is passed through without
        # conversion — confirm Activations accepts it in this form.
        lh, po, ah, aa = fields
    return Activations(lh, po, ah, aa)
def from_pytt(cls, fields, *, is_grad=False) -> "Activations":
    """Build an instance from the output tuple of a PyTorch Transformers model.

    Torch tensors are converted to xp arrays, and missing outputs are filled
    with placeholders.

    Args:
        fields: Model output sequence. When it has exactly four entries they
            are read as (last hidden, pooler output, all hidden, all
            attentions); otherwise only entry 0 is used.
        is_grad: Forwarded unchanged to the constructor.

    Returns:
        ``cls(last_hidden, pooled, all_hidden, all_attn, is_grad=is_grad)``.
    """
    if len(fields) == 4:
        last_hidden, pooled, all_hidden, all_attn = fields
    else:
        # Only the last hidden state is available; the rest are placeholders.
        last_hidden, pooled, all_hidden, all_attn = fields[0], (), [], []
    last_hidden = torch2xp(last_hidden)
    xp = get_array_module(last_hidden)
    if not isinstance(pooled, tuple):
        pooled = torch2xp(pooled)
    else:
        # A tuple stands in for a missing pooler output: use an empty array
        # with the same dtype as the last hidden state.
        pooled = xp.zeros((0,), dtype=last_hidden.dtype)
    all_hidden = [torch2xp(layer) for layer in all_hidden]
    all_attn = [torch2xp(layer) for layer in all_attn]
    return cls(last_hidden, pooled, all_hidden, all_attn, is_grad=is_grad)
def _update_pytorch_averages(self, sgd, *, init_steps=1):
    """Keep running averages of the wrapped PyTorch model's parameters.

    Returns immediately when ``sgd.averages`` is None. Otherwise, each entry
    of ``self._model.state_dict()`` is stored under the key
    ``"pytorch_{self.id}_{name}"``. On first sight, an entry seeds its average
    with a copy and sets its update count to ``init_steps``. On later visits,
    ``self.ops.update_averages`` folds in the current value.

    NOTE(review): ``sgd.nr_update[key] += 1`` runs before the membership
    check, so ``nr_update`` is presumably a defaultdict-like mapping —
    confirm.
    """
    if sgd.averages is None:
        return
    state = self._model.state_dict()
    for param_name, tensor in state.items():
        avg_key = f"pytorch_{self.id}_{param_name}"
        sgd.nr_update[avg_key] += 1
        current = torch2xp(tensor)
        if avg_key not in sgd.averages:
            # First time this parameter is seen: seed the average and reset
            # its update counter.
            sgd.averages[avg_key] = current.copy()
            sgd.nr_update[avg_key] = init_steps
        else:
            self.ops.update_averages(
                sgd.averages[avg_key], current, sgd.nr_update[avg_key]
            )
def from_pytt(cls, fields, *, is_grad=False) -> "Activations":
    """Create Activations from the output tuples produced by PyTorch Transformers.

    Includes converting torch tensors to xp, and handling missing values.

    Slots: lh (last hidden), po (pooler output), ah (all hidden states),
    aa (all attentions). Missing trailing slots and ``None`` slots are
    treated as empty lists.

    Args:
        fields: Model output sequence with at most four entries.
        is_grad: Forwarded unchanged to the constructor.

    Returns:
        ``cls(lh, po, ah, aa, is_grad=is_grad)``, where po, ah and aa are
        lists of xp arrays.

    Raises:
        ValueError: If ``fields`` has more than four entries (from the
            four-way unpacking).
    """
    fields = list(fields)
    # Make sure we have 4 elements
    while len(fields) < 4:
        fields.append([])
    # Normalize None to [] so the map(torch2xp, ...) calls below always get
    # an iterable (previously None was passed through unchanged).
    fields = [f if f is not None else [] for f in fields]
    lh, po, ah, aa = fields
    # Convert last_hidden_state to xp
    lh = torch2xp(lh)
    # A tuple made only of None entries means "no pooler output".
    if isinstance(po, tuple) and all(x is None for x in po):
        po = []
    po = list(map(torch2xp, po))
    ah = list(map(torch2xp, ah))
    aa = list(map(torch2xp, aa))
    return cls(lh, po, ah, aa, is_grad=is_grad)