Example #1
def pad_trajectories(trajectories, boundary=10):
    """Pad trajectories to a bucket length that is a multiple of boundary."""

    # trajectories is a list of tuples of (observations, actions, rewards)
    # observations's length is one more than actions and rewards
    #
    # i.e. observations = (o_0, o_1, ... o_{T-1}, o_T)
    #           actions = (a_0, a_1, ... a_{T-1})
    #           rewards = (r_0, r_1, ... r_{T-1})

    # Given the above, compute t_max = max(T) + 1, the maximum observation
    # length over all trajectories.
    t_max = max(o.shape[0] for (o, a, r) in trajectories)

    # t_max - 1 is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max - 1) / boundary))

    # So all obs will be padded to bucket_length + 1 and actions and rewards to
    # bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_lengths = []
    reward_masks = []
    for (o, a, r) in trajectories:
        # Determine the amount to pad; this is the same for obs, actions and rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)
        if num_to_pad == 0:
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(onp.ones_like(r, dtype=np.int32))
            continue

        # First pad observations.
        padding_config = [(0, num_to_pad, 0)]
        for _ in range(o.ndim - 1):
            padding_config.append((0, 0, 0))
        padding_config = tuple(padding_config)
        padding_value = 0.0 if o.dtype == np.float32 else 0
        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        assert a.ndim == 1 and r.ndim == 1
        padding_config = ((0, num_to_pad, 0), )
        action_padding_value = 0.0 if a.dtype == np.float32 else 0
        reward_padding_value = 0.0 if r.dtype == np.float32 else 0
        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later.
        reward_mask = onp.ones_like(r, dtype=np.int32)
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

    return padded_lengths, np.stack(reward_masks), np.stack(
        padded_observations), np.stack(padded_actions), np.stack(
            padded_rewards)
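A note on notation: these snippets come from different projects; typically jnp (and often np) aliases jax.numpy, while onp aliases plain NumPy. In every example, each entry of the padding_config passed to lax.pad is a (low, high, interior) triple for one dimension of the operand. A minimal standalone illustration, assuming only jax is installed:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.)                       # [0., 1., 2.]
lax.pad(x, jnp.float32(0), [(1, 2, 1)])  # -> [0., 0., 0., 1., 0., 2., 0., 0.]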
Example #2
def _lu_jvp_rule(primals, tangents):
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    if a_dot is ad_util.zero:
        return (core.pack(
            (lu, pivots)), ad.TangentTuple((ad_util.zero, ad_util.zero)))

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    permutation = lu_pivots_to_permutation(pivots, m)
    batch_dims = a_shape[:-2]
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots), (lu_dot, ad_util.zero)
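A note on the derivation in the comments above: writing F = inv(L) . A' . inv(U), the identity L'U + LU' = A' gives inv(L) . L' + U' . inv(U) = F. Because the factorization returns a unit lower-triangular L, inv(L) . L' is strictly lower triangular while U' . inv(U) is upper triangular, so the two terms can be read off as the strictly-lower and upper parts of F: L' = L . tril(F, -1) and U' = triu(F) . U. In the code, lau corresponds to F (computed by the two triangular solves) and l_dot, u_dot are L' and U'.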
Example #3
  def testPadGrad(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)
Example #4
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = onp.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)
Example #5
    def testPad(self):
        R = onp.random.RandomState(0).randn

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)])
        x = R(5, 10).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)])
        x = R(5, 10, 3).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)
Example #6
def lu_jvp_rule(primals, tangents):
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax._dtype(a)
    k = min(m, n)

    # TODO(phawkins): use a gather rather than a matrix multiplication here.
    permutation = lu_pivots_to_permutation(pivots, m)
    p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
    x = np.matmul(p, a_dot)

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Example #7
def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes,
           fill_value=0):
  """Similar to lax.dynamic_slice, but handles arrays with dynamic sizes.

  Returns fill_value instead of clamping start_indices for those elements that
  would overflow the side of the array.

  Args:
    operand: the array to slice
    start_indices: the offset of the start of the slice
    dynamic_slice_sizes: the true (unpadded) size of the slice
    static_slice_sizes: the padded size of the slice, which must be known at
      compile time. The static size must be larger than the dynamic size.
    fill_value: value with which to replace masked-out elements.
  Returns:
    An array with static shape `static_slice_sizes`, padded from its true
    (dynamic) size `dynamic_slice_sizes`.
  """
  # We must pad the input array so the dynamic_slice is guaranteed to fall
  # entirely in bounds.
  padded = lax.pad(operand,
                   jnp.array(0, operand.dtype),
                   [(0, d, 0) for d in static_slice_sizes])
  out = lax.dynamic_slice(padded, tuple(jnp.int32(i) for i in start_indices),
                          static_slice_sizes)
  return _mask(out, dynamic_slice_sizes, fill_value)
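A minimal, self-contained sketch of the over-padding trick used above (omitting the final _mask step that the snippet relies on):

import jax.numpy as jnp
from jax import lax

operand = jnp.arange(6)                    # shape (6,)
static_size = (4,)
padded = lax.pad(operand, jnp.array(0, operand.dtype),
                 [(0, d, 0) for d in static_size])            # shape (10,)
out = lax.dynamic_slice(padded, (jnp.int32(4),), static_size)
# out == [4, 5, 0, 0]: the part of the slice that overflows the original array
# is filled from the padding instead of being clamped back in bounds.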
Example #8
 def outslice(inslice, padding_value):
     assert inslice is None or np.array_equal(inslice.shape, inshape)
     return (lax.pad(inslice, padding_value,
                     zip(offset, np.array(outcount.shape) - limit, interior))
             if insize else
             jnp.full(outcount.shape, padding_value, operand.dtype))
Example #9
def dctn(x, type=2, s=None, axes=None, norm=None):
    if type != 2:
        raise NotImplementedError('Only DCT type 2 is implemented.')

    if axes is None:
        axes = range(x.ndim)

    if len(axes) == 1:
        return dct(x,
                   n=s[0] if s is not None else None,
                   axis=axes[0],
                   norm=norm)

    if s is not None:
        ns = {a: n for a, n in zip(axes, s)}
        pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0)
                for a in range(x.ndim)]
        x = lax.pad(x, jnp.array(0, x.dtype), pads)

    if len(axes) == 2:
        return _dct2(x, axes=axes, norm=norm)

    # compose high-D DCTs from 2D and 1D DCTs:
    for axes_block in [axes[i:i + 2] for i in range(0, len(axes), 2)]:
        x = dctn(x, axes=axes_block, norm=norm)
    return x
Example #10
def spatial_pad(pad_vertical, pad_horizontal, operand):
    """
  Wrapper around lax.pad which pads spatial dimensions (horizontal and vertical)
  with zeros, without any interior padding.
  """
    zero = (0, 0, 0)
    return lax.pad(operand, 0.,
                   (zero, pad_vertical + (0, ), pad_horizontal + (0, ), zero))
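For instance, assuming the operand is laid out as NHWC (batch, height, width, channels) and jnp is jax.numpy, the wrapper pads only the two middle dimensions:

x = jnp.zeros((2, 8, 8, 3))
spatial_pad((1, 1), (2, 2), x).shape   # -> (2, 10, 12, 3)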
Example #11
def block_diag(*arrs):
  if len(arrs) == 0:
    arrs = [jnp.zeros((1, 0))]
  arrs = jnp._promote_dtypes(*arrs)
  bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
  if bad_shapes:
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(arrs[bad_shapes[0]], bad_shapes[0]))
  arrs = [jnp.atleast_2d(a) for a in arrs]
  acc = arrs[0]
  dtype = lax.dtype(acc)
  for a in arrs[1:]:
    _, c = a.shape
    a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
    acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
    acc = lax.concatenate([acc, a], dimension=0)
  return acc
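Hypothetical usage of the block_diag above (the snippet targets an older JAX where jnp._promote_dtypes exists; it mirrors jax.scipy.linalg.block_diag). The pad/concatenate loop places each argument on the diagonal and zero-fills the rest:

a = jnp.array([[1., 2.], [3., 4.]])
b = jnp.array([[5.]])
block_diag(a, b)
# [[1. 2. 0.]
#  [3. 4. 0.]
#  [0. 0. 5.]]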
Example #12
def _pad_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
           cts_in: ShapedArray) -> np.ndarray:
    padding_config = eqn.params['padding_config']

    inval = invals[idx]
    j = np.eye(inval.size, dtype=inval.dtype)
    j = j.reshape(inval.shape * 2)
    for _ in range(inval.ndim):
        padding_config += ((0, 0, 0), )

    j = lax.pad(j, np.zeros((), j.dtype), padding_config)
    return j
Example #13
def dct(x, type=2, n=None, axis=-1, norm=None):
    if type != 2:
        raise NotImplementedError('Only DCT type 2 is implemented.')

    axis = _canonicalize_axis(axis, x.ndim)
    if n is not None:
        x = lax.pad(x, jnp.array(0, x.dtype),
                    [(0, n - x.shape[axis] if a == axis else 0, 0)
                     for a in range(x.ndim)])

    N = x.shape[axis]
    v = _dct_interleave(x, axis)
    V = jnp.fft.fft(v, axis=axis)
    k = lax.expand_dims(jnp.arange(N), [a for a in range(x.ndim) if a != axis])
    out = V * _W4(N, k)
    out = 2 * out.real
    if norm == 'ortho':
        out = _dct_ortho_norm(out, axis)
    return out
Example #14
def _update_slice(operand, update, start_indices, update_dims):
    """
  Similar to lax.dynamic_update_slice, but handles padded updates where padding
  values should not overwrite existing values in the array.

  Args:
  operand: the array to update
  update: the padded array to write
  start_indices: the offset at which to write `update`.
  update_dims: the true dimensions of the padded update `update`. Only values
    inside the rectangle given by `update_dims` will be overwritten."""
    operand_shape = operand.shape
    operand = lax.pad(operand, jnp.array(0, operand.dtype),
                      [(0, d, 0) for d in update.shape])
    start_indices = tuple(jnp.int32(i) for i in start_indices)
    t = lax.dynamic_slice(operand, start_indices, update.shape)
    t = _mask(update, update_dims, t)
    operand = lax.dynamic_update_slice(operand, t, start_indices)
    return lax.slice(operand, [0] * operand.ndim, operand_shape)
Example #15
def _get_f_and_eqn(params, primitive, *inputs):
    if primitive is None:
        f = lambda x: x
        eqn = None

    else:
        if primitive is lax.pad_p:
            # TODO(romann): find a way to call primitive.bind directly.
            f = lambda *inputs: lax.pad(*inputs, **params)

        elif primitive is lax.conv_general_dilated_p:
            # TODO(romann): find a way to call primitive.bind directly.
            f = lambda *inputs: lax.conv_general_dilated(*inputs, **params)

        else:
            f = lambda *inputs: primitive.bind(*inputs, **params)

        eqn = jax.make_jaxpr(f)(*inputs).eqns[0]

    return eqn, f
Example #16
def onnx_conv(
    x,
    w,
    b=None,
    group=1,
    kernel_shape=None,
    pads=None,
    strides=None,
    dilations=None,
    auto_pad=None,
):
    kernel_shape = kernel_shape or w.shape
    spatial_size = w.ndim - 2
    strides = strides or [1] * spatial_size

    # TODO some pad does not need a PadOp
    if not auto_pad or auto_pad == "NOTSET":
        if pads is not None:
            x = lax.pad(x, 0.0, pads)
        pad_mode = "VALID"
    elif auto_pad == "SAME_UPPER":
        pad_mode = "SAME"
    elif auto_pad == "VALID":
        pad_mode = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("Conv with auto_pad `SAME_LOWER`")
    else:
        raise ValueError("Invalid auto_pad attribute: {}".format(auto_pad))

    lhs_dilation = [1] * (w.ndim - 2)
    rhs_dilation = dilations or [1] * (w.ndim - 2)

    if b is not None:
        b = b.reshape([1, w.shape[0]] + [1] * spatial_size)
    else:
        b = 0

    out = lax.conv_general_dilated(x, w, strides, pad_mode, lhs_dilation,
                                   rhs_dilation, None, group, 1)
    return out + b
Example #17
 def fun(x):
     return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])
Example #18
 def p(x):
     return lax.pad(x, np.array(0., x.dtype), [(1, 1, 1)])
Example #19
 def f_jax(x):
     return lax.pad(x, np.float_(5.), ((0, 0, 0), (0, 0, 0), (1, 1, 1)))
Example #20
 def p(x):
     return lax.pad(x, 0, [(1, 1, 1)])
Example #21
def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
    pads = [(0, 0, 0)] * x.ndim
    pads[axis] = (low, high, interior)
    return lax.pad(x, jnp.array(fill_value, x.dtype), pads)
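For example, assuming jnp is jax.numpy:

_pad_in_dim(jnp.arange(1, 4), low=1, high=2, fill_value=9)
# -> [9, 1, 2, 3, 9, 9]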
Example #22
 def testPad(self, shape, dtype, pads, bdims):
     rng = jtu.rand_small(self.rng())
     fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
     self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)
Example #23
 def fun2(x):
     # Can be converted only if enable_xla is on, due to negative padding.
     return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])
Example #24
def pad_trajectories(trajectories, boundary=20):
    """Pad trajectories to a bucket length that is a multiple of boundary.

  Args:
    trajectories: list[(observation, actions, rewards)], where each observation
      is shaped (t+1,) + OBS and actions & rewards are shaped (t,), with the
      length of the list being B (batch size).
    boundary: int, bucket length, the actions and rewards are padded to integer
      multiples of boundary.

  Returns:
    tuple: (padding lengths, reward_mask, padded_observations, padded_actions,
        padded_rewards) where padded_observations is shaped (B, RT+1) + OBS and
        padded_actions, padded_rewards & reward_mask are shaped (B, RT).
        Where RT is max(t) rounded up to an integer multiple of boundary.
        padded_length is how much padding we've added and
        reward_mask is 1s for actual rewards and 0s for the padding.
  """

    # Let's compute max(t) over all trajectories.
    t_max = max(r.shape[0] for (_, _, r, _) in trajectories)

    # t_max is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    # So all obs will be padded to bucket_length + 1 and actions and rewards to
    # bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_infos = collections.defaultdict(list)
    padded_lengths = []
    reward_masks = []

    for (o, a, r, i) in trajectories:
        # Determine the amount to pad; this is the same for obs, actions and rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)
        if num_to_pad == 0:
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(onp.ones_like(r, dtype=np.int32))
            if i:
                for k, v in i.items():
                    padded_infos[k].append(v)
            continue

        # First pad observations.
        padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] *
                               (o.ndim - 1))

        padding_value = get_padding_value(o.dtype)
        action_padding_value = get_padding_value(a.dtype)
        reward_padding_value = get_padding_value(r.dtype)

        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] *
                               (a.ndim - 1))
        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)

        assert r.ndim == 1
        padding_config = ((0, num_to_pad, 0), )
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later.
        reward_mask = onp.ones_like(r, dtype=np.int32)
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

        if i:
            for k, v in i.items():
                # Create a padding configuration for this value.
                padding_config = ([(0, num_to_pad, 0)] +
                                  [(0, 0, 0)] * (v.ndim - 1))
                padded_infos[k].append(lax.pad(v, 0.0, tuple(padding_config)))

    # Now stack these padded_infos if they exist.
    stacked_padded_infos = None
    if padded_infos:
        stacked_padded_infos = {
            k: np.stack(v)
            for k, v in padded_infos.items()
        }

    return padded_lengths, np.stack(reward_masks), np.stack(
        padded_observations), np.stack(padded_actions), np.stack(
            padded_rewards), stacked_padded_infos
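To make the bucketing arithmetic concrete: with the default boundary=20, a batch whose longest trajectory has t = 33 gives bucket_length = 20 * ceil(33 / 20) = 40, so observations are padded to length 41 while actions, rewards and the reward mask are padded to length 40.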
Example #25
def test_pad(shape, dtype, padding_config, rng_factory):
    rng = rng_factory(np.random)
    args = [rng(shape, dtype), rng((), dtype)]
    op = lambda *args: lax.pad(*args, padding_config)
    tu.check_lazy_fun(op, *args)
Example #26
 def pad(x):
   return lax.pad(x, jnp.array(1., x.dtype), padding_config)
Example #27
 def testPad(self, shape, dtype, pads, bdims, rng_factory):
     rng = rng_factory(self.rng())
     fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
     self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)