def test_diff(self):
    dtype: torch.dtype = torch.float
    size = (2, 3, 4)
    arr = torch.arange(1, reduce(lambda x, y: x * y, size) + 1, dtype=dtype).view(size)
    arr2 = torch.arange(1, 11)
    size = (3, 2)
    arr3 = torch.arange(1, reduce(lambda x, y: x * y, size) + 1, dtype=dtype).view(size)
    arrs = (arr, arr2, arr3)
    enc = self.de.encode(arrs)
    decs = self.de.decode(enc)
    # perturb one element of the third decoded tensor so it no longer
    # matches its original
    decs[2][1][1] = 1.11
    for orig, dec, should_eq in zip(arrs, decs, [True, True, False]):
        if should_eq:
            self.assertTrue(TorchConfig.equal(orig, dec))
        else:
            self.assertFalse(TorchConfig.equal(orig, dec))
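The test above round-trips a tuple of tensors through encode/decode, then flips a single element of one decoded tensor so the last comparison must fail. A minimal standalone sketch of the exact-equality semantics it relies on, using plain torch.equal (assuming TorchConfig.equal performs an exact element-wise comparison like torch.equal), and of how reduce computes the element count for arange:

from functools import reduce
import torch

size = (2, 3, 4)
n_elem = reduce(lambda x, y: x * y, size)   # 24 elements
a = torch.arange(1, n_elem + 1, dtype=torch.float).view(size)
b = a.clone()
assert torch.equal(a, b)      # identical shape and contents
b[1][1][1] = 1.11             # perturb a single element, as the test does
assert not torch.equal(a, b)  # exact comparison now fails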
Example #2
def assertTensorEquals(self, should, tensor):
    self.assertEqual(should.shape, tensor.shape)
    try:
        eq = TorchConfig.equal(should, tensor)
    except RuntimeError as e:
        logger.error(f'error comparing {should} with {tensor}')
        raise e
    if not eq:
        logger.error(f'tensor {should} does not equal {tensor}')
    self.assertTrue(eq)
Example #3
def test_datasets(self):
    tc = TorchConfig(False)
    fac = self.fac
    stash = fac('dataloader_stash')
    dataset = fac('mnist_batch_stash')
    dataset.delegate_attr = True
    ds_name = 'train val test'.split()
    batch_size = dataset.delegate.batch_size
    name: str
    ds: Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
    for name, ds in zip(ds_name, stash.get_data_by_split()):
        ds_start = 0
        ds_stash = dataset.splits[name]
        # concatenate the split's per-batch data and labels for comparison
        ds_data = torch.cat(tuple(map(lambda x: x[0], ds)))
        ds_labels = torch.cat(tuple(map(lambda x: x[1], ds)))
        dpts = sum(map(lambda b: len(b.data_point_ids), ds_stash.values()))
        logger.info(f'name: {name}, stash size: {len(ds_stash)}, ' +
                    f'data set size: {len(ds)}, ' +
                    f'stash X batch_size: {len(ds_stash) * batch_size}, ' +
                    f'data/label shapes: {ds_data.shape}/{ds_labels.shape}, ' +
                    f'data points: {dpts}')
        assert len(ds) == len(ds_stash)
        assert dpts == ds_labels.shape[0]
        assert ds_labels.shape[0] == ds_data.shape[0]
        for id, batch in ds_stash:
            # each batch must match a contiguous slice of the split data
            ds_end = ds_start + len(batch)
            dsb_labels = ds_labels[ds_start:ds_end]
            dsb_data = ds_data[ds_start:ds_end]
            ds_start = ds_end
            blabels = batch.get_labels()
            bdata = batch.get_data()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'data point ids: {batch.data_point_ids}')
                logger.debug(f'ds/batch labels: {dsb_labels}/{blabels}')
            assert tc.equal(dsb_labels, blabels)
            assert tc.equal(dsb_data, bdata)
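test_datasets verifies that every batch from the stash lines up with a contiguous slice of the split's concatenated data and labels. A minimal standalone sketch of that running-offset bookkeeping, with hypothetical batch sizes:

import torch

batches = [torch.arange(0, 4), torch.arange(4, 10), torch.arange(10, 12)]
all_data = torch.cat(batches)
start = 0
for batch in batches:
    end = start + len(batch)
    # each re-sliced span must reproduce the original batch exactly
    assert torch.equal(all_data[start:end], batch)
    start = end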
def _trans_test(self, arrs: Sequence[Tensor]):
    enc = self.de.encode(arrs)
    decs = self.de.decode(enc)
    for orig, dec in zip(arrs, decs):
        self.assertTrue(TorchConfig.equal(orig, dec))
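_trans_test factors out the encode/decode round trip used in test_diff. A hypothetical caller, assuming the same self.de encoder/decoder fixture and a lossless round trip as in the tests above:

def test_roundtrip_mixed_shapes(self):
    # hypothetical test method; shapes and dtypes chosen arbitrarily
    arrs = (torch.zeros(2, 3), torch.arange(5), torch.ones(4, 4, dtype=torch.float))
    self._trans_test(arrs)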