def test_diff(self):
    """Verify that a modified decoded tensor no longer equals its source.

    Three tensors are encoded and then decoded with ``self.de``.  A single
    element of the third decoded tensor is overwritten afterwards, so only
    the first two source/decoded pairs are expected to compare as equal.

    """
    dtype: torch.dtype = torch.float

    def sequential(shape):
        # tensor holding 1..N (N = product of ``shape``) viewed as ``shape``
        count = reduce(lambda x, y: x * y, shape)
        return torch.arange(1, count + 1, dtype=dtype).view(shape)

    originals = (
        sequential((2, 3, 4)),
        torch.arange(1, 11),
        sequential((3, 2)),
    )
    decoded = self.de.decode(self.de.encode(originals))
    # overwrite one element of the last decoded tensor so it must differ
    decoded[2][1][1] = 1.11
    expectations = [True, True, False]
    for source, restored, should_match in zip(originals, decoded, expectations):
        check = self.assertTrue if should_match else self.assertFalse
        check(TorchConfig.equal(source, restored))
def assertTensorEquals(self, should, tensor):
    """Assert that ``tensor`` has the same shape and content as ``should``.

    The shapes are compared first, then the contents are compared with
    ``TorchConfig.equal``.  On a mismatch the offending tensors are logged
    and also reported in the assertion failure message.

    :param should: the expected tensor

    :param tensor: the tensor to check against ``should``

    :raises RuntimeError: re-raised, after logging both tensors, when
                          ``TorchConfig.equal`` raises it

    """
    self.assertEqual(should.shape, tensor.shape)
    try:
        eq = TorchConfig.equal(should, tensor)
    except RuntimeError:
        logger.error(f'error comparing {should} with {tensor}')
        # bare raise propagates the active exception unchanged
        raise
    if not eq:
        # only format the (potentially large) tensors when the check fails
        msg = f'tensor {should} does not equal {tensor}'
        logger.error(msg)
        self.fail(msg)
def test_datasets(self):
    """Check the batch stash splits against the data loader stash output.

    For each of the ``train``, ``val`` and ``test`` splits, the data and
    labels returned by the ``dataloader_stash`` are concatenated and then
    compared, slice by slice, with the labels and data of every batch in
    the matching split of the ``mnist_batch_stash``.

    """
    tc = TorchConfig(False)
    fac = self.fac
    stash = fac('dataloader_stash')
    dataset = fac('mnist_batch_stash')
    dataset.delegate_attr = True
    ds_name = 'train val test'.split()
    batch_size = dataset.delegate.batch_size
    name: str
    ds: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    for name, ds in zip(ds_name, stash.get_data_by_split()):
        ds_start = 0
        ds_stash = dataset.splits[name]
        # each element of ``ds`` is indexed as (data, labels)
        ds_data = torch.cat(tuple(x[0] for x in ds))
        ds_labels = torch.cat(tuple(x[1] for x in ds))
        dpts = sum(len(b.data_point_ids) for b in ds_stash.values())
        logger.info(f'{name}: stash size: {len(ds_stash)}, '
                    f'data set size: {len(ds)}, '
                    f'stash X batch_size: {len(ds_stash) * batch_size}, '
                    f'data/label shapes: {ds_data.shape}/{ds_labels.shape}, '
                    f'data points: {dpts}')
        # unittest assertions are used rather than ``assert`` statements so
        # the checks still run when Python is started with -O
        self.assertEqual(len(ds), len(ds_stash))
        self.assertEqual(dpts, ds_labels.shape[0])
        self.assertEqual(ds_labels.shape[0], ds_data.shape[0])
        for _batch_id, batch in ds_stash:
            # advance a window of ``len(batch)`` rows over the concatenated
            # data set tensors
            ds_end = ds_start + len(batch)
            dsb_labels = ds_labels[ds_start:ds_end]
            dsb_data = ds_data[ds_start:ds_end]
            ds_start = ds_end
            blabels = batch.get_labels()
            bdata = batch.get_data()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'data point ids: {batch.data_point_ids}')
                logger.debug(f'ds/batch labels: {dsb_labels}/{blabels}')
            self.assertTrue(tc.equal(dsb_labels, blabels))
            self.assertTrue(tc.equal(dsb_data, bdata))
def _trans_test(self, arrs: Sequence[Tensor]):
    """Encode then decode ``arrs`` with ``self.de`` and assert that every
    decoded tensor equals the tensor it came from.

    :param arrs: the tensors to encode and decode

    """
    encoded = self.de.encode(arrs)
    decoded = self.de.decode(encoded)
    for source, restored in zip(arrs, decoded):
        matches = TorchConfig.equal(source, restored)
        self.assertTrue(matches)