def test_extract_batch_size():
    """Tests the behavior of extracting the batch size."""
    expected = 11
    # Each case nests a size-11 leading dimension differently; `len("test string") == 11`.
    cases = [
        "test string",
        torch.zeros(11, 10, 9, 8),
        {'test': torch.zeros(11, 10)},
        [torch.zeros(11, 10)],
        {'test': [{'test': [torch.zeros(11, 10)]}]},
    ]
    for batch in cases:
        assert extract_batch_size(batch) == expected
def extract_batch_size(self, batch: Any) -> int:
    """Infer the batch size from ``batch``, cache it on ``self``, and return it.

    Falls back to ``1`` when inference recurses too deeply (e.g. deeply nested batches).
    """
    try:
        # NOTE: resolves to the module-level helper, not this method.
        size = extract_batch_size(batch)
    except RecursionError:
        size = 1
    self.batch_size = size  # the setter converts it to `Tensor`
    return size
def _check_warning_raised(data, expected):
    """Assert that inferring the batch size from ``data`` yields ``expected`` and warns.

    Args:
        data: the batch-like object to infer the batch size from.
        expected: the batch size expected to be found (and echoed in the warning).
    """
    with pytest.warns(
        UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."
    ):
        # BUG FIX: the original referenced the undefined name ``batch``
        # instead of the ``data`` parameter, raising NameError when called.
        assert extract_batch_size(data) == expected
    # clear the per-process warning cache so the next call warns again
    warning_cache.clear()
def _extract_batch_size(
    self, value: Union[_ResultMetric, _ResultMetricCollection], batch_size: Optional[int], meta: _Metadata
) -> int:
    """Resolve the batch size for logging ``value``, caching the result on ``self``.

    Resolution order: explicit ``batch_size`` argument, then the cached
    ``self.batch_size``, then inference from ``self.batch`` (only when it can
    matter, i.e. a tensor metric with on-epoch mean reduction), else ``1``.
    """
    # check if we have extracted the batch size already
    if batch_size is None:
        batch_size = self.batch_size
    if batch_size is not None:
        return batch_size
    # default when inference is skipped or not applicable
    batch_size = 1
    is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
    # only pay the inference cost when the batch size actually affects the result:
    # tensor values reduced by mean on epoch level
    if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
        batch_size = extract_batch_size(self.batch)
        self.batch_size = batch_size  # cache for subsequent calls within this batch
    return batch_size
def test_sample_metadata_field() -> None:
    """
    Test that the string constant we use to identify the metadata field is really matching the
    field name in SampleWithMetadata
    """
    batch_size = 5
    xyz = (6, 7, 8)
    zero = torch.zeros((batch_size,) + xyz)
    sample = Sample(
        metadata=DummyPatientMetadata,
        image=zero,
        mask=zero,
        labels=torch.zeros((batch_size,) + (2,) + xyz),
    )
    fields = vars(sample)
    assert len(fields) == 4
    assert SAMPLE_METADATA_FIELD in fields
    # Lightning attempts to determine the batch size by trying to find a tensor field in the sample.
    # This only works if any field other than Metadata is first.
    assert extract_batch_size(fields) == batch_size
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """One-time lazy initialization of hivemind on the first training batch.

    Infers the per-machine batch size from ``batch`` when it was not configured,
    then initializes hivemind. Subsequent calls are no-ops.
    """
    if self._hivemind_initialized:
        return
    self._hivemind_initialized = True
    # todo (sean): we could technically support a dynamic batch size by inferring each step
    # and passing it to the ``hivemind.Optimizer``.
    if self._batch_size is None:
        try:
            self._batch_size = extract_batch_size(batch)
            log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}")
        except (MisconfigurationException, RecursionError) as e:
            raise MisconfigurationException(
                "We tried to infer the batch size from the first batch of data. "
                "Please provide the batch size to the Strategy by "
                "``Trainer(strategy=HivemindStrategy(batch_size=x))``. "
            ) from e
    self._initialize_hivemind()
def _check_error_raised(data):
    """Assert that inferring the batch size from ``data`` raises ``MisconfigurationException``.

    Args:
        data: a batch-like object from which no batch size can be inferred.
    """
    with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
        # BUG FIX: the original referenced the undefined name ``batch``
        # instead of the ``data`` parameter, raising NameError when called.
        extract_batch_size(data)
def _check_warning_not_raised(data, expected):
    """Assert that extracting the batch size from ``data`` yields ``expected`` silently."""
    warning_pattern = "Trying to infer the `batch_size`"
    with no_warning_call(match=warning_pattern):
        assert extract_batch_size(data) == expected
def extract_batch_size(self, batch: Any) -> None:
    """Infer the batch size from ``batch`` and store it on ``self``.

    Falls back to ``1`` when inference recurses too deeply.
    """
    try:
        # NOTE: resolves to the module-level helper, not this method.
        inferred = extract_batch_size(batch)
    except RecursionError:
        inferred = 1
    self.batch_size = inferred
def _check_warning_not_raised(data, expected):
    """Assert that extracting the batch size from ``data`` returns ``expected``
    and emits no warnings at all.

    Args:
        data: the batch-like object to infer the batch size from.
        expected: the batch size expected to be returned.
    """
    # FIX: ``pytest.warns(None)`` is deprecated since pytest 7.0 and removed in
    # pytest 8 — record warnings with the stdlib instead and assert none fired.
    import warnings

    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert extract_batch_size(data) == expected
    assert len(record) == 0