Ejemplo n.º 1
0
 def test_broadcasting(self, fixture_shapes, fixture_exception_handling):
     """Compare ``shapes_broadcasting`` with NumPy's broadcasting result.

     The reference shape is what ``np.broadcast`` reports for empty arrays
     built from ``fixture_shapes``. If NumPy raises ``ValueError`` for them,
     ``shapes_broadcasting`` must either raise ``ValueError`` (truthy
     ``raise_exception``) or return ``None`` (falsy ``raise_exception``).
     """
     should_raise = fixture_exception_handling

     # Use NumPy as the oracle; a ValueError here means "no reference shape".
     try:
         placeholders = [np.empty(shape) for shape in fixture_shapes]
         reference = np.broadcast(*placeholders).shape
     except ValueError:
         reference = None

     # Compatible shapes: the result must match NumPy's answer.
     if reference is not None:
         result = shapes_broadcasting(*fixture_shapes,
                                      raise_exception=should_raise)
         assert result == reference
         return

     # Incompatible shapes, exceptions disabled: failure is signalled by None.
     if not should_raise:
         result = shapes_broadcasting(*fixture_shapes,
                                      raise_exception=should_raise)
         assert result is None
         return

     # Incompatible shapes, exceptions enabled: a ValueError is required.
     with pytest.raises(ValueError):
         shapes_broadcasting(*fixture_shapes, raise_exception=should_raise)
Ejemplo n.º 2
0
 def test_get_broadcastable_dist_samples(self, samples_to_broadcast):
     """Check the output shapes of ``get_broadcastable_dist_samples``.

     The fixture yields ``(size, samples, broadcast_shape)``. When
     ``broadcast_shape`` is ``None`` the call must raise ``ValueError``.
     Otherwise the reported output shape must equal ``broadcast_shape``,
     every output must have the shape computed below from its input sample,
     and all output shapes must broadcast together to ``broadcast_shape``.
     """
     size, samples, broadcast_shape = samples_to_broadcast

     # Error case first: the fixture marks it with a missing target shape.
     if broadcast_shape is None:
         with pytest.raises(ValueError):
             get_broadcastable_dist_samples(samples, size=size)
         return

     size_tuple = to_tuple(size)
     results, reported_shape = get_broadcastable_dist_samples(
         samples, size=size, return_out_shape=True)
     assert reported_shape == broadcast_shape

     for sample, result in zip(samples, results):
         in_shape = sample.shape
         prefix_len = min(len(size_tuple), len(in_shape))
         if in_shape[:prefix_len] != size_tuple:
             # The sample does not start with the size dims: shape unchanged.
             wanted = in_shape
         else:
             # Size dims come first, then singleton dims filling the rank
             # difference to the broadcast shape, then the remaining dims.
             filler = (1, ) * (len(broadcast_shape) - len(in_shape))
             wanted = size_tuple + filler + in_shape[len(size_tuple):]
         assert result.shape == wanted

     result_shapes = [result.shape for result in results]
     assert shapes_broadcasting(*result_shapes) == broadcast_shape
Ejemplo n.º 3
0
 def test_type_check_success(self):
     """Pass a mix of accepted argument types to ``shapes_broadcasting``.

     The arguments are an int, a float, an empty tuple, a list, a tuple, a
     0-d array and a 1-d array; the expected result is ``(3, )``.
     """
     accepted_args = (
         3,
         3.0,
         (),
         [3],
         (3, ),
         np.array(3),
         np.array([3]),
     )
     expected = (3, )
     assert shapes_broadcasting(*accepted_args) == expected
Ejemplo n.º 4
0
 def test_type_check_raises(self, bad_input):
     """Expect ``TypeError`` for ``bad_input`` whatever ``raise_exception`` is.

     The call is made once with ``raise_exception=True`` and once with
     ``raise_exception=False``; both must raise ``TypeError``.
     """
     for raise_flag in (True, False):
         with pytest.raises(TypeError):
             shapes_broadcasting(bad_input, (), raise_exception=raise_flag)