Beispiel #1
0
 def _first_call_assign_qtensor_infos_to_mod_outputs_tensor(
     self,
     output: torch.Tensor,
     qtensor_id: List[int],
 ) -> torch.Tensor:
     """
     Record the QTensorInfo of one module output tensor.

     Helper for _first_call_assign_qtensor_infos_to_mod_outputs so that
     iterables of tensors can be processed without duplicating code.

     If ``output`` does not yet carry a ``_qtensor_info`` attribute, a new
     ``QTensorInfo`` is attached. Its id is the current ``qtensor_id[0]``,
     and that counter is then incremented. An existing attribute is reused
     as-is, even if its value is ``None``. In both cases the tensor's info
     is appended to ``self.output_qtensor_infos``.

     Args:
         output: the module output tensor to tag.
         qtensor_id: single-element list used as a mutable counter shared
             across calls; ``qtensor_id[0]`` is the next id to hand out.

     Returns:
         ``output`` itself (possibly with ``_qtensor_info`` newly attached).
     """
     try:
         # Reuse whatever info an earlier call already attached; only a
         # missing attribute (AttributeError) triggers creation below,
         # mirroring a hasattr() check.
         info = output._qtensor_info  # type: ignore[attr-defined]
     except AttributeError:
         next_id = qtensor_id[0]
         # TODO: use actual dtype instead of defaulting to float
         info = QTensorInfo(next_id, torch.float)
         output._qtensor_info = info  # type: ignore[attr-defined]
         qtensor_id[0] = next_id + 1
     self.output_qtensor_infos.append(info)
     # TODO(future PR): add an observer if needed
     return output