def test_shape(self, input_param, input, expected_shape):
    """Check the ``"img"`` entry produced by ``AddCoordinateChannelsd``.

    Args:
        input_param: keyword arguments used to construct ``AddCoordinateChannelsd``.
        input: data dictionary handed to the transform; must contain ``"img"``.
        expected_shape: shape the transformed ``"img"`` entry is expected to have.

    The test asserts four things:

    * the output image has exactly the same type as the input image;
    * when the output is a ``torch.Tensor``, it is on the same device;
    * the output has ``expected_shape``;
    * channel 0 of the output is close to channel 0 of the input.
    """
    transform = AddCoordinateChannelsd(**input_param)
    out_img = transform(input)["img"]
    # The source image is read only after the transform has been applied,
    # preserving the access order of the original test.
    src_img = input["img"]

    self.assertEqual(type(out_img), type(src_img))
    if isinstance(out_img, torch.Tensor):
        # The device comparison is only performed for tensor outputs.
        self.assertEqual(out_img.device, src_img.device)
    self.assertEqual(out_img.shape, expected_shape)
    # Presumably the original data stays in channel 0 with coordinate
    # channels added after it -- inferred from this comparison; confirm.
    assert_allclose(src_img[0, ...], out_img[0, ...])
def test_channel_dim(self, input_param, input):
    """Expect ``ValueError`` when building and applying the transform.

    Args:
        input_param: keyword arguments used to construct ``AddCoordinateChannelsd``.
        input: data dictionary handed to the constructed transform.

    Both the construction and the call happen inside the checked region, so
    the test passes if either step raises ``ValueError``. Presumably the
    parameterized cases supply an invalid channel/spatial dimension (judging
    by the test name) -- confirm against the test case definitions.
    """

    def _build_and_apply():
        return AddCoordinateChannelsd(**input_param)(input)

    self.assertRaises(ValueError, _build_and_apply)
def test_shape(self, input_param, input, expected_shape):
    """Check the shape and first channel of the transformed ``"img"`` entry.

    Args:
        input_param: keyword arguments used to construct ``AddCoordinateChannelsd``.
        input: data dictionary handed to the transform; must contain ``"img"``.
        expected_shape: expected shape of the transformed ``"img"`` entry.

    The shapes are compared as lists, so any sequence type for
    ``expected_shape`` works. Channel 0 of the output must equal channel 0
    of the input exactly.
    """
    output = AddCoordinateChannelsd(**input_param)(input)
    out_img = output["img"]

    self.assertEqual(list(out_img.shape), list(expected_shape))
    np.testing.assert_array_equal(input["img"][0, ...], out_img[0, ...])