示例#1
0
def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
                                  broadcast_dimensions):
    """Papply rule for ``lax.broadcast_in_dim``.

    The rule treats the single operand in ``vals`` as a per-device slice whose
    hidden (split) dimension ``dims[0]`` has been removed. That hidden operand
    dimension lands at output position ``broadcast_dimensions[dim]``. That
    output position is dropped from the per-device result and returned as the
    new hidden dimension.

    Args:
      name: name of the parallel axis (not used by this rule).
      size: size of the papply axis. The rule assumes it equals the extent of
        the hidden dimension (one slice per device) -- TODO confirm against
        the papply interpreter.
      vals: one-element sequence holding the per-device operand.
      dims: one-element sequence holding the operand's hidden dimension index.
      shape: full output shape, including the hidden dimension.
      broadcast_dimensions: for each full-operand dimension, its index in
        ``shape``.

    Returns:
      A pair ``(result, out_dim)``. ``result`` is the per-device broadcast
      (``shape`` with ``out_dim`` removed). ``out_dim`` is the hidden
      dimension's position in the full output.

    Raises:
      ValueError: if the output's hidden dimension size differs from ``size``.
    """
    operand, = vals
    dim, = dims
    bdims = tuple(broadcast_dimensions)
    out_dim = bdims[dim]
    # The hidden dimension is split across devices, so broadcasting must not
    # change its extent.
    if shape[out_dim] != size:
        raise ValueError(
            "broadcast_in_dim changes hidden dimension size: {} to {}".format(
                size, shape[out_dim]))
    # Drop the hidden operand dimension. Output positions after out_dim move
    # down by one once out_dim is removed from the per-device output shape.
    # broadcast_in_dim requires distinct broadcast dims, so no remaining entry
    # equals out_dim.
    sub_bdims = tuple(d if d < out_dim else d - 1
                      for i, d in enumerate(bdims) if i != dim)
    sub_shape = tuple(s for i, s in enumerate(shape) if i != out_dim)
    return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
示例#2
0
def _cond_pred_bcast_select(pred, x, y):
  """Elementwise select between ``x`` and ``y`` driven by ``pred``.

  ``pred``'s axes are mapped onto the leading axes of ``x``, so ``pred``'s
  shape is expected to be a prefix of ``x``'s. ``pred`` is broadcast to
  ``x``'s full shape before calling ``lax.select``.
  """
  target_shape = onp.shape(x)
  leading_axes = list(range(onp.ndim(pred)))
  full_pred = lax.broadcast_in_dim(pred, target_shape, leading_axes)
  return lax.select(full_pred, x, y)