Example #1
import numpy as onp   # plain NumPy, aliased `onp` as in the JAX sources this snippet is from

from jax import lax


def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
                         old_sizes):
    operand, = vals   # the single sharded operand
    axis, = axes      # logical axis of the operand that is split across devices

    def filter_ones(xs):
        # Drop size-1 entries (unused in this variant of the rule).
        return [x for x in xs if x != 1]

    def find_new_axis(old_axis, old_sizes, new_sizes):
        # Find where the split dimension reappears in new_sizes: it must keep
        # its size and be preceded by the same number of elements as before.
        left = onp.prod(old_sizes[:old_axis])
        size = old_sizes[old_axis]
        prod = 1
        for i, cur_sz in enumerate(new_sizes):
            if prod == left and cur_sz == size:
                return i
            prod = prod * cur_sz
        return None

    if dimensions is None:
        new_axis = find_new_axis(axis, old_sizes, new_sizes)
        if new_axis is not None:
            # Drop the split dimension from the requested shape, reshape the
            # local shard, and report where the split axis now lives.
            new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
            return lax.reshape(operand, new_sizes_,
                               dimensions=dimensions), new_axis
        else:
            raise NotImplementedError(
                'papply of reshape that would change hidden dimension size')
    else:
        raise NotImplementedError('papply of reshape with `dimensions`')
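For reference, a minimal standalone sketch of the lax.reshape call that this rule wraps; the array and shapes below are made up for illustration, and `dimensions` is the optional pre-reshape permutation that the rule above declines to handle.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)                   # shape (2, 3, 4)
y = lax.reshape(x, (2, 12))                           # same 24 elements, shape (2, 12)
z = lax.reshape(x, (3, 4, 2), dimensions=(1, 2, 0))   # permute axes first, then reshape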
Example #2
from functools import partial
import numpy as onp
from jax import core, lax


def _update_arrays(i, aval, xs, x):
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
    # Recurse over tuple components, pairing each buffer in xs with its update in x.
    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
  else:
    # Add a leading length-1 axis so x can be written as slot i of xs.
    x = lax.reshape(x, (1,) + onp.shape(x))
    return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
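A small usage sketch of the reshape-then-update pattern above, on a plain array with made-up shapes rather than an abstract-value tree:

import jax.numpy as jnp
from jax import lax

xs = jnp.zeros((5, 3))                                   # stacked buffer of 5 rows
x = jnp.arange(3.0)                                      # one row to write
x = lax.reshape(x, (1,) + x.shape)                       # add a leading length-1 axis -> (1, 3)
xs = lax.dynamic_update_index_in_dim(xs, x, 2, axis=0)   # write it as row 2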
Example #3
from jax import lax


# In the JAX sources this rule is defined inside a factory that supplies
# chooser_fun (a max- or min-style reduction) from the enclosing scope.
def chooser_taylor_rule(primals_in, series_in, **params):
  operand, = primals_in
  gs, = series_in
  primal_out = chooser_fun(operand, **params)
  axes = params.pop("axes", None)
  primal_dtype = gs[0].dtype
  # Broadcastable shape of the reduced output: size 1 along the reduced axes.
  shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  # 1 where an element attains the chosen value, 0 elsewhere.
  location_indicators = lax.convert_element_type(
      lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
  counts = lax._reduce_sum(location_indicators, axes)
  def _reduce_chooser_taylor_rule(g):
    # Average each series coefficient over the tied locations.
    return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
  series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
  return primal_out, series_out
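The rule above uses private helpers (lax._eq_meet, lax._reduce_sum); the same location-indicator idea can be sketched with public jax.numpy calls, assuming a plain max reduction over a single axis:

import jax.numpy as jnp

def max_with_indicators(operand, axis):
  primal_out = jnp.max(operand, axis=axis)
  # 1 where an element attains the maximum along `axis`, 0 elsewhere.
  indicators = (operand == jnp.max(operand, axis=axis, keepdims=True))
  indicators = indicators.astype(operand.dtype)
  counts = jnp.sum(indicators, axis=axis)   # ties share the contribution evenly
  return primal_out, indicators, counts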
Example #4
import numpy as onp
from jax import lax


def _conv_general_dilated_papply_rule(
    name, size, vals, dims, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, feature_group_count, precision, **unused_kwargs):
  lhs, rhs = vals
  lhs_dim, rhs_dim = dims
  lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
  if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
    # Only the input batch dimension is split: reinsert it locally as a
    # length-1 axis and run the convolution unchanged on each shard.
    lhs = lax.reshape(lhs, tuple(onp.insert(lhs.shape, lhs_dim, 1)))
    out = lax.conv_general_dilated(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, precision)
    return out, lhs_dim
  else:
    raise NotImplementedError(
        "splitting a convolution along anything but input batch dimension")
Example #5
import numpy as np
from jax import lax

# psum, all_to_all and _moveaxis below are module-level helpers from the
# JAX parallel-primitive module this rule lives in.
def _all_to_all_split_axis_rule(split_name, vals, params):
    concat_axis = params['concat_axis']
    split_axis = params['split_axis']
    axis_names = params['axis_name']
    assert isinstance(axis_names, tuple)
    x, = vals

    split_pos = list(axis_names).index(split_name)
    before_axes = axis_names[:split_pos]
    after_axes = axis_names[split_pos + 1:]

    # Unflatten split_axis into (before_size, split_name_size, after_size)
    split_name_size = psum(1, split_name)
    before_size = psum(1, before_axes)
    after_size = psum(1, after_axes)
    unroll_shape = list(x.shape)
    unroll_shape[split_axis:split_axis +
                 1] = [before_size, split_name_size, after_size]
    unroll_x = lax.reshape(x, unroll_shape)

    # Perform the all_to_all per axis-name group, materializing the before,
    # split and after pieces as three leading result axes.
    if before_axes:
        out_before = all_to_all(unroll_x,
                                before_axes,
                                split_axis,
                                concat_axis=0)
    else:
        out_before = _moveaxis(split_axis, 0, unroll_x)
    out_split = all_to_all(out_before,
                           split_name,
                           split_axis + 1,
                           concat_axis=1)
    if after_axes:
        out_after = all_to_all(out_split,
                               after_axes,
                               split_axis + 2,
                               concat_axis=2)
    else:
        out_after = _moveaxis(split_axis + 2, 2, out_split)

    # Flatten the concat axes and move them to the right position
    y = out_after.reshape((np.prod(out_after.shape[:3]), *out_after.shape[3:]))
    return _moveaxis(0, concat_axis, y)
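A toy sketch of the reshape bookkeeping used above, on an ordinary array and without any collectives; the 2x3 factoring of the leading axis is made up:

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6 * 4).reshape(6, 4)     # pretend axis 0 is split over a 2x3 mesh
unrolled = lax.reshape(x, (2, 3, 4))    # unflatten the split axis into two factors
y = unrolled.reshape((np.prod(unrolled.shape[:2]), *unrolled.shape[2:]))
# y.shape == (6, 4) again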
Example #6
from jax import lax


def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
    operand, = vals   # the single sharded operand
    axis, = axes      # logical axis of the operand that is split across devices

    def filter_ones(xs):
        # Drop size-1 entries; return a list so len() works on the result.
        return [x for x in xs if x != 1]

    def find_new_axis(old_axis, old_sizes, new_sizes):
        # Bail out if the reshape changes the number of non-unit dimensions.
        if len(filter_ones(new_sizes)) != len(filter_ones(old_sizes)):
            return None
        num_before = len(filter_ones(old_sizes[:old_axis]))
        sz = old_sizes[old_axis]
        # Count off the non-unit dims preceding the split axis; the first
        # non-unit dim after them must carry the split axis' size.
        for i, new_sz in enumerate(new_sizes):
            if num_before == 0:
                if new_sz == sz:
                    return i
                elif new_sz != 1:
                    return None
            elif new_sz != 1:
                num_before -= 1
        return None

    err = NotImplementedError(
        'papply of reshape that would change hidden dimension size')

    if dimensions is None:
        new_axis = find_new_axis(axis, old_sizes, new_sizes)
        if new_axis is not None:
            # The elements before and after the split axis must regroup
            # consistently, otherwise the sharded dimension would be mixed
            # into other dimensions.  (lax.prod is the product-of-sizes
            # helper in the JAX version this snippet was written against.)
            if (lax.prod(old_sizes[:axis]) != lax.prod(new_sizes[:new_axis])
                    or lax.prod(old_sizes[axis + 1:]) != lax.prod(
                        new_sizes[new_axis + 1:])):
                raise err
            new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
            return lax.reshape(operand, new_sizes_,
                               dimensions=dimensions), new_axis
        else:
            raise err
    else:
        raise NotImplementedError('papply of reshape with `dimensions`')
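A worked example with made-up sizes: a size-8 dimension split as axis 1 of old_sizes survives the reshape and reappears at position 2 of new_sizes, so the rule drops it from the local reshape and reports new_axis = 2.

old_sizes = (4, 8, 1, 3)    # global shape before the reshape; axis 1 is the split one
new_sizes = (1, 4, 8, 3)    # requested global shape after the reshape
# find_new_axis(1, old_sizes, new_sizes) returns 2, the element-count checks pass,
# and the local shard is reshaped to new_sizes[:2] + new_sizes[3:] == (1, 4, 3).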