def _promote_shapes(fun_name, *args):
  """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.

  Args:
    fun_name: name of the calling function. It is forwarded to
      ``_rank_promotion_warning_or_error`` when rank promotion happens and
      ``config.jax_numpy_rank_promotion`` is not ``"allow"``.
    *args: array-like operands. Their shapes are read with ``np.shape``.

  Returns:
    ``args`` unchanged (the original tuple) in these cases:
      - there are fewer than two operands;
      - all operands have the same rank;
      - at most one distinct non-scalar rank is present, so lax scalar
        promotion suffices.
    Otherwise, a list of the operands passed through ``_broadcast_to``.
  """
  if len(args) < 2:
    return args

  shapes = [np.shape(arg) for arg in args]
  if all(len(shapes[0]) == len(s) for s in shapes[1:]):
    return args  # no need for rank promotion, so rely on lax promotion

  # Rank-0 shapes are the empty tuple, which is falsy; they are excluded here.
  nonscalar_ranks = {len(shp) for shp in shapes if shp}
  if len(nonscalar_ranks) < 2:
    return args  # only scalars differ in rank; rely on lax scalar promotion

  if config.jax_numpy_rank_promotion != "allow":
    _rank_promotion_warning_or_error(fun_name, shapes)

  if config.jax_dynamic_shapes:
    # With dynamic shapes we don't support singleton-dimension broadcasting;
    # we instead broadcast out to the full shape as a temporary workaround.
    # NOTE(review): equal-rank inputs return early above, so they never reach
    # this full-shape broadcast even when they differ in singleton dims --
    # confirm this is intended.
    res_shape = lax.broadcast_shapes(*shapes)
    return [_broadcast_to(arg, res_shape) for arg in args]

  # Static shapes: only left-pad each shape with 1s up to the result rank,
  # and leave size-1 dimension broadcasting to lax.
  result_rank = len(lax.broadcast_shapes(*shapes))
  return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
          for arg, shp in zip(args, shapes)]
def _broadcast_arrays(*args):
  """Broadcast ``args`` to one common shape, similar to ``np.broadcast_arrays``.

  Unlike NumPy's version, the results are not views.

  If there are no arguments, or every argument's shape is symbolically equal
  to the first one, no broadcasting happens. In that case, ``ndarray``
  instances and NumPy scalars are returned as-is and everything else goes
  through ``_asarray``. Otherwise every argument is passed through
  ``_broadcast_to`` with the shape computed by ``lax.broadcast_shapes``.

  Returns:
    A list with one entry per argument.
  """
  arg_shapes = [np.shape(a) for a in args]
  shapes_agree = (not arg_shapes) or all(
      core.symbolic_equal_shape(arg_shapes[0], shp) for shp in arg_shapes)

  if not shapes_agree:
    common_shape = lax.broadcast_shapes(*arg_shapes)
    return [_broadcast_to(a, common_shape) for a in args]

  # TODO(mattjj): remove the array(arg) here
  out = []
  for a in args:
    passthrough = isinstance(a, ndarray) or np.isscalar(a)
    out.append(a if passthrough else _asarray(a))
  return out