def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape, broadcast_dimensions): operand, = vals dim, = dims out_dim = broadcast_dimensions[dim] if shape[out_dim] != shape[dim]: raise ValueError( "broadcast_in_dim changes hidden dimension size: {} to {}".format( shape[dim], shape[out_dim])) sub_bdims = tuple(onp.delete(broadcast_dimensions, dim)) sub_shape = tuple(onp.delete(shape, out_dim)) return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
def _cond_pred_bcast_select(pred, x, y): bcast_pred = lax.broadcast_in_dim(pred, onp.shape(x), list(range(onp.ndim(pred)))) return lax.select(bcast_pred, x, y)