def test_broadcasting(self, fixture_shapes, fixture_exception_handling):
    """Compare ``shapes_broadcasting`` with NumPy's broadcasting result.

    ``np.broadcast`` on empty arrays of the fixture shapes serves as the
    reference. When NumPy rejects the shapes with ``ValueError``, the test
    expects ``shapes_broadcasting`` to raise ``ValueError`` if
    ``raise_exception`` is truthy and to return ``None`` otherwise. When
    NumPy accepts them, the returned shape must equal NumPy's.
    """
    shapes = fixture_shapes
    raise_exception = fixture_exception_handling

    # Reference answer: let NumPy broadcast throwaway arrays of each shape.
    try:
        reference = np.broadcast(*[np.empty(shape) for shape in shapes]).shape
    except ValueError:
        reference = None

    # Compatible shapes: the result must match NumPy regardless of the flag.
    if reference is not None:
        result = shapes_broadcasting(*shapes, raise_exception=raise_exception)
        assert result == reference
        return

    # Incompatible shapes, silent mode: failure is signalled by ``None``.
    if not raise_exception:
        result = shapes_broadcasting(*shapes, raise_exception=raise_exception)
        assert result is None
        return

    # Incompatible shapes, strict mode: failure is signalled by an exception.
    with pytest.raises(ValueError):
        shapes_broadcasting(*shapes, raise_exception=raise_exception)
def test_get_broadcastable_dist_samples(self, samples_to_broadcast):
    """Check the shapes produced by ``get_broadcastable_dist_samples``.

    The fixture yields ``(size, samples, broadcast_shape)``. A
    ``broadcast_shape`` of ``None`` marks a case where the call is expected
    to raise ``ValueError``. Otherwise the test checks the reported output
    shape, the shape of every returned sample, and that the returned shapes
    broadcast together to ``broadcast_shape``.
    """
    size, samples, broadcast_shape = samples_to_broadcast

    # Error case first: the fixture flags it with a missing broadcast shape.
    if broadcast_shape is None:
        with pytest.raises(ValueError):
            get_broadcastable_dist_samples(samples, size=size)
        return

    size_ = to_tuple(size)
    outs, out_shape = get_broadcastable_dist_samples(
        samples, size=size, return_out_shape=True
    )
    assert out_shape == broadcast_shape

    n_size = len(size_)
    for sample, reshaped in zip(samples, outs):
        in_shape = sample.shape
        # Slicing clamps to the tuple length, so this compares the leading
        # ``min(n_size, len(in_shape))`` entries against ``size_``.
        if in_shape[:n_size] == size_:
            # Samples that start with the size prefix are expected to have
            # singleton axes inserted right after that prefix.
            padding = (1,) * (len(broadcast_shape) - len(in_shape))
            expected_shape = size_ + padding + in_shape[n_size:]
        else:
            expected_shape = in_shape
        assert reshaped.shape == expected_shape

    # All returned shapes together must broadcast to the expected shape.
    assert shapes_broadcasting(*(out.shape for out in outs)) == broadcast_shape
def test_type_check_success(self):
    """Check that ``shapes_broadcasting`` accepts a mix of shape spec types.

    The call receives an int, a float, an empty tuple, a list, a tuple, a
    0-d array and a 1-d array in one go, and must return ``(3,)``.
    """
    accepted_specs = (3, 3.0, (), [3], (3,), np.array(3), np.array([3]))
    assert shapes_broadcasting(*accepted_specs) == (3,)
def test_type_check_raises(self, bad_input):
    """Check that an invalid shape spec raises ``TypeError`` in both modes.

    The test expects ``TypeError`` whether ``raise_exception`` is ``True``
    or ``False``.
    """
    for flag in (True, False):
        with pytest.raises(TypeError):
            shapes_broadcasting(bad_input, tuple(), raise_exception=flag)