def check_tensordot_mismatch_simple_error(_np_a, _np_b, axes):
    """Check that NumPy and ``nps`` both reject a tensordot with ``ValueError``.

    ``_np_a`` and ``_np_b`` are wrapped with ``nps.array`` before any
    tensordot call is made. Then ``np.tensordot`` is called on the raw
    operands and ``nps.tensordot`` on the wrapped ones, in that order. Each
    call, with the same ``axes``, must raise ``ValueError``.
    """
    ns_lhs = nps.array(_np_a)
    ns_rhs = nps.array(_np_b)
    # Reference implementation first, then the implementation under test.
    backends = (
        (np, _np_a, _np_b),
        (nps, ns_lhs, ns_rhs),
    )
    for impl, lhs, rhs in backends:
        with pytest.raises(ValueError):
            impl.tensordot(lhs, rhs, axes=axes)
def check_tensordot_axes_type_error(_np_a, _np_b, axes):
    """Check how tensordot handles an ``axes`` argument that should be rejected.

    When ``is_array_like(axes)`` is false, both ``np.tensordot`` (on the raw
    operands) and ``nps.tensordot`` (on the ``nps.array``-wrapped operands)
    must raise ``TypeError``.

    When ``is_array_like(axes)`` is true, only ``nps.tensordot`` is exercised,
    and it must raise ``NotImplementedError``.
    """
    ns_lhs = nps.array(_np_a)
    ns_rhs = nps.array(_np_b)
    if not is_array_like(axes):
        # Scalar-style bad axes: both implementations must reject with TypeError.
        with pytest.raises(TypeError):
            np.tensordot(_np_a, _np_b, axes=axes)
        with pytest.raises(TypeError):
            nps.tensordot(ns_lhs, ns_rhs, axes=axes)
        return
    # TODO (bcp): Remove test once tensordot over multiple axes is implemented.
    with pytest.raises(NotImplementedError):
        nps.tensordot(ns_lhs, ns_rhs, axes=axes)