def size_of_tensor(obj: object) -> int:
    r"""**Get size in bytes of Tensor, torch.nn.Module or standard object.**

    Specific routines are defined for `torch.Tensor` objects and
    `torch.nn.Module` objects. They calculate how much memory in bytes the
    tensor data of those objects consumes.
    If another object is passed, `sys.getsizeof` will be called on it.

    This function works similarly to C++'s sizeof operator.

    Notes
    -----
    - For a tensor, the result is ``element_size() * numel()``. This is the
      size of the tensor's own elements, not of its underlying storage. A
      view into a larger tensor therefore reports only the elements it spans.
    - For a module, the result is the sum over its parameters and buffers
      (recursively, including submodules). Tensors shared between
      submodules, such as tied weights, are counted once, because
      ``Module.parameters()`` / ``Module.buffers()`` deduplicate by default.
    - For any other object, ``sys.getsizeof`` is shallow. It does not follow
      references into contained objects.

    Parameters
    ----------
    obj
        Object whose size will be measured. Despite the function name, any
        object is accepted.

    Returns
    -------
    int
        Size in bytes of the object
    """
    if torch.is_tensor(obj):
        return obj.element_size() * obj.numel()
    if isinstance(obj, torch.nn.Module):
        # Buffers (e.g. BatchNorm running stats) hold memory too, so they are
        # counted alongside the trainable parameters.
        return sum(
            size_of_tensor(tensor)
            for tensor in itertools.chain(obj.buffers(), obj.parameters())
        )
    return sys.getsizeof(obj)