def test_extract_batch_size():
    """Tests the behavior of extracting the batch size."""
    # Every case below is built so the inferred batch size is 11.
    cases = [
        "test string",                                  # presumably falls back to len() — 11 chars
        torch.zeros(11, 10, 9, 8),                      # first dimension of a bare tensor
        {'test': torch.zeros(11, 10)},                  # tensor inside a dict
        [torch.zeros(11, 10)],                          # tensor inside a list
        {'test': [{'test': [torch.zeros(11, 10)]}]},    # deeply nested containers
    ]
    for case in cases:
        assert extract_batch_size(case) == 11
Example #2
 def extract_batch_size(self, batch: Any) -> int:
     """Infer the batch size of ``batch``, cache it on ``self`` and return it.

     Falls back to a size of 1 when the batch structure is nested too deeply
     to traverse.
     """
     try:
         size = extract_batch_size(batch)
     except RecursionError:
         size = 1
     # the setter converts it to `Tensor`
     self.batch_size = size
     return size
Example #3
 def _check_warning_raised(data, expected):
     """Assert that inferring the batch size of ``data`` warns and yields ``expected``.

     Clears the warning cache afterwards so subsequent checks can warn again.
     """
     with pytest.warns(
             UserWarning,
             match=
             f"Trying to infer the `batch_size` .* we found is {expected}."
     ):
         # BUG FIX: the original called extract_batch_size(batch), referencing an
         # outer-scope name and silently ignoring the `data` parameter.
         assert extract_batch_size(data) == expected
     warning_cache.clear()
Example #4
    def _extract_batch_size(
        self, value: Union[_ResultMetric, _ResultMetricCollection], batch_size: Optional[int], meta: _Metadata
    ) -> int:
        """Return the batch size to use for ``value``, inferring and caching it when possible."""
        # An explicitly supplied size wins, then a previously cached one.
        if batch_size is not None:
            return batch_size
        if self.batch_size is not None:
            return self.batch_size

        holds_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
        # Only mean-reduced epoch metrics backed by tensors need a real batch size.
        if self.batch is None or not holds_tensor or not (meta.on_epoch and meta.is_mean_reduction):
            return 1

        inferred = extract_batch_size(self.batch)
        self.batch_size = inferred  # cache for subsequent calls
        return inferred
Example #5
def test_sample_metadata_field() -> None:
    """
    Test that the string constant we use to identify the metadata field is really matching the
    field name in SampleWithMetadata
    """
    batch_size = 5
    spatial = (6, 7, 8)
    # Same zero tensor is shared by image and mask, as in the original test.
    volume = torch.zeros((batch_size,) + spatial)
    s = Sample(metadata=DummyPatientMetadata,
               image=volume,
               mask=volume,
               labels=torch.zeros((batch_size, 2) + spatial))
    fields = vars(s)
    assert len(fields) == 4
    assert SAMPLE_METADATA_FIELD in fields
    # Lightning attempts to determine the batch size by trying to find a tensor field in the sample.
    # This only works if any field other than Metadata is first.
    assert extract_batch_size(fields) == batch_size
 def on_train_batch_start(self,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
     """One-time hivemind setup on the first training batch, inferring the batch size if unset."""
     if self._hivemind_initialized:
         return
     self._hivemind_initialized = True
     # todo (sean): we could technically support a dynamic batch size by inferring each step
     # and passing it to the ``hivemind.Optimizer``.
     if self._batch_size is None:
         try:
             self._batch_size = extract_batch_size(batch)
         except (MisconfigurationException, RecursionError) as e:
             raise MisconfigurationException(
                 "We tried to infer the batch size from the first batch of data. "
                 "Please provide the batch size to the Strategy by "
                 "``Trainer(strategy=HivemindStrategy(batch_size=x))``. "
             ) from e
         else:
             log.info(
                 f"Found per machine batch size automatically from the batch: {self._batch_size}"
             )
     self._initialize_hivemind()
Example #7
 def _check_error_raised(data):
     """Assert that inferring the batch size of ``data`` raises a MisconfigurationException."""
     with pytest.raises(MisconfigurationException,
                        match="We could not infer the batch_size"):
         # BUG FIX: the original called extract_batch_size(batch), referencing an
         # outer-scope name and silently ignoring the `data` parameter.
         extract_batch_size(data)
Example #8
 def _check_warning_not_raised(data, expected):
     """Assert ``extract_batch_size(data)`` equals ``expected`` without the inference warning."""
     with no_warning_call(match="Trying to infer the `batch_size`"):
         actual = extract_batch_size(data)
         assert actual == expected
Example #9
 def extract_batch_size(self, batch: Any) -> None:
     """Infer the batch size of ``batch`` and store it on ``self.batch_size``.

     Defaults to 1 when the batch structure is nested too deeply to traverse.
     """
     try:
         size = extract_batch_size(batch)
     except RecursionError:
         size = 1
     self.batch_size = size
 def _check_warning_not_raised(data, expected):
     """Assert ``extract_batch_size(data)`` equals ``expected`` and emits no warning.

     FIX: ``pytest.warns(None)`` is deprecated and removed in pytest 7; record
     warnings with the stdlib instead.
     """
     import warnings

     with warnings.catch_warnings(record=True) as record:
         warnings.simplefilter("always")
         assert extract_batch_size(data) == expected
     assert len(record) == 0