示例#1
0
def add_batched(batched_args, batch_dims):
  """Batching rule for adding two values that may carry a batch axis.

  Args:
    batched_args: pair ``(x, y)`` of operands to add.
    batch_dims: pair ``(bdx, bdy)`` giving each operand's batch axis; the
      sentinel ``not_mapped`` presumably marks an operand with no batch axis.

  Returns:
    A pair ``(sum, out_dim)`` where ``sum`` is ``add_jaxvals`` of the
    (possibly adjusted) operands and ``out_dim`` is the batch axis of the sum.
  """
  bdx, bdy = batch_dims
  x, y = batched_args

  if bdx == bdy:
    # Batch axes already agree (or both are unmapped): add directly.
    out_dim = bdx
  elif bdx is not_mapped:
    # Only y is batched: give x a matching axis of size y.shape[bdy] at bdy.
    x = broadcast(x, y.shape[bdy], bdy)
    out_dim = bdy
  elif bdy is not_mapped:
    # Only x is batched: give y a matching axis of size x.shape[bdx] at bdx.
    y = broadcast(y, x.shape[bdx], bdx)
    out_dim = bdx
  else:
    # Both are batched on different axes: move x's batch axis onto y's.
    x = moveaxis(x, bdx, bdy)
    out_dim = bdy

  return add_jaxvals(x, y), out_dim
示例#2
0
文件: batching.py 项目: wayfeng/jax
def add_batched(batched_args, batch_dims):
    """Batching rule for adding two values that may carry a batch axis.

    Args:
        batched_args: pair ``(x, y)`` of operands to add.
        batch_dims: pair ``(bdx, bdy)`` giving each operand's batch axis; the
            sentinel ``not_mapped`` presumably marks an operand with no batch
            axis.

    Returns:
        A pair ``(sum, out_dim)``: ``add_jaxvals`` of the (possibly adjusted)
        operands, and the batch axis of that sum.
    """
    bdx, bdy = batch_dims
    x, y = batched_args

    # Matching axes, or x's abstract value is the unit: add as-is, keep bdx.
    # (get_aval is only consulted when the axes differ, via short-circuit.)
    if bdx == bdy or core.get_aval(x) == core.abstract_unit:
        return add_jaxvals(x, y), bdx

    # Only y is batched: broadcast x to a batch axis of y's size at bdy.
    if bdx is not_mapped:
        return add_jaxvals(broadcast(x, y.shape[bdy], bdy), y), bdy

    # Only x is batched: broadcast y to a batch axis of x's size at bdx.
    if bdy is not_mapped:
        return add_jaxvals(x, broadcast(y, x.shape[bdx], bdx)), bdx

    # Both batched on different axes: align x's batch axis with y's.
    return add_jaxvals(moveaxis(x, bdx, bdy), y), bdy
示例#3
0
文件: ad.py 项目: jbampton/jax
def add_tangents(x, y):
  """Add two tangent values, treating ``Zero`` instances as the identity.

  The checks use an exact type test (``type(...) is Zero``), so subclasses
  of ``Zero`` are not short-circuited.

  Returns:
    ``y`` when ``x`` is a ``Zero`` (even if ``y`` is also ``Zero``);
    otherwise ``x`` when ``y`` is a ``Zero``; otherwise ``add_jaxvals(x, y)``.
  """
  if type(x) is Zero:
    return y
  return x if type(y) is Zero else add_jaxvals(x, y)