def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
                         old_sizes):
  """Papply rule for `reshape`: relocate the split axis into the new shape.

  Args:
    name: name of the mapped axis (unused here).
    size: size of the mapped axis (unused here).
    vals: tuple of one element, the sharded operand.
    axes: tuple of one int, the operand dimension that is split.
    new_sizes: requested (unsharded) output shape.
    dimensions: optional permutation argument of `lax.reshape`.
    old_sizes: original (unsharded) input shape.

  Returns:
    A pair `(reshaped_operand, new_axis)` where `new_axis` is the position
    of the split dimension in the new shape.

  Raises:
    NotImplementedError: if the split axis cannot be located unchanged in
      the new shape, or if `dimensions` is given.
  """
  operand, = vals
  axis, = axes

  def find_new_axis(old_axis, old_sizes, new_sizes):
    # The split axis survives the reshape iff some position in the new shape
    # has the same size and the same product of sizes to its left.
    left = onp.prod(old_sizes[:old_axis])
    size = old_sizes[old_axis]
    prod = 1
    for i, cur_sz in enumerate(new_sizes):
      if prod == left and cur_sz == size:
        return i
      prod = prod * cur_sz
    return None

  if dimensions is None:
    new_axis = find_new_axis(axis, old_sizes, new_sizes)
    if new_axis is not None:
      # Drop the split dimension from the target shape: each shard reshapes
      # only its local slice.
      new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
      return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
    else:
      raise NotImplementedError(
          'papply of reshape that would change hidden dimension size')
  else:
    raise NotImplementedError('papply of reshape with `dimensions`')
def _update_arrays(i, aval, xs, x):
  """Write `x` into slot `i` along the leading axis of `xs`, recursing
  through abstract tuple values elementwise."""
  assert isinstance(aval, core.AbstractValue)
  if isinstance(aval, core.AbstractTuple):
    # Recurse over the tuple structure, threading the same index `i`.
    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
  # Prepend a unit axis so the update matches a length-1 slice of `xs`.
  expanded = lax.reshape(x, (1,) + onp.shape(x))
  return lax.dynamic_update_index_in_dim(xs, expanded, i, axis=0)
def chooser_taylor_rule(primals_in, series_in, **params):
  """Taylor-mode (jet) rule for reduce-chooser primitives (e.g. max/min).

  Propagates a truncated Taylor series through a reduction that selects
  extreme elements by averaging the series coefficients over every
  location that attains the chosen value.
  """
  operand, = primals_in
  gs, = series_in
  # NOTE: `chooser_fun` is a closure variable (presumably lax.reduce_max or
  # similar — confirm at the enclosing definition). It must be called before
  # `axes` is popped so it still receives that parameter.
  primal_out = chooser_fun(operand, **params)
  axes = params.pop("axes", None)
  primal_dtype = gs[0].dtype
  # Shape with the reduced axes kept as size 1, so primal_out broadcasts
  # back against the operand.
  shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  # 1 where the operand attains the chosen value, 0 elsewhere (ties allowed).
  location_indicators = lax.convert_element_type(
      lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
  # Number of attaining locations per output element; used to average ties.
  counts = lax._reduce_sum(location_indicators, axes)
  def _reduce_chooser_taylor_rule(g):
    # Average the coefficient `g` over all attaining locations.
    return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes),
                   counts)
  series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
  return primal_out, series_out
def _conv_general_dilated_papply_rule(
    name, size, vals, dims, window_strides, padding, lhs_dilation,
    rhs_dilation, dimension_numbers, feature_group_count, precision,
    **unused_kwargs):
  """Papply rule for `conv_general_dilated`.

  Only the case of an lhs split along the input batch dimension with an
  unsplit rhs is supported; each shard then runs a batch-1 convolution.
  """
  lhs, rhs = vals
  lhs_dim, rhs_dim = dims
  lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
  if rhs_dim is not None or lhs_dim != lhs_spec_batch_dim:
    raise NotImplementedError(
        "splitting a convolution along anything but input batch dimension")
  # Reinsert a unit axis where the mapped batch dimension was removed.
  batched_shape = tuple(onp.insert(lhs.shape, lhs_dim, 1))
  lhs = lax.reshape(lhs, batched_shape)
  out = lax.conv_general_dilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, precision)
  return out, lhs_dim
def _all_to_all_split_axis_rule(split_name, vals, params):
  """Split-axis rule for `all_to_all` over one name of a tuple axis_name.

  Decomposes the all_to_all over the tuple of axis names into three
  stages — the names before `split_name`, `split_name` itself, and the
  names after it — then flattens the three resulting leading axes back
  into `concat_axis`.
  """
  concat_axis = params['concat_axis']
  split_axis = params['split_axis']
  axis_names = params['axis_name']
  assert isinstance(axis_names, tuple)
  x, = vals
  split_pos = list(axis_names).index(split_name)
  before_axes = axis_names[:split_pos]
  after_axes = axis_names[split_pos + 1:]

  # Flatten the split_dim into (before, split_name, after) sub-axes.
  # Axis sizes are obtained by summing 1 over each named axis.
  split_name_size = psum(1, split_name)
  before_size = psum(1, before_axes)
  after_size = psum(1, after_axes)
  unroll_shape = list(x.shape)
  unroll_shape[split_axis:split_axis + 1] = [before_size, split_name_size,
                                             after_size]
  unroll_x = lax.reshape(x, unroll_shape)
  if before_axes:
    out_before = all_to_all(unroll_x, before_axes, split_axis, concat_axis=0)
  else:
    # No leading names: just move the (size-1) "before" sub-axis to front.
    out_before = _moveaxis(split_axis, 0, unroll_x)
  out_split = all_to_all(out_before, split_name, split_axis + 1,
                         concat_axis=1)
  if after_axes:
    out_after = all_to_all(out_split, after_axes, split_axis + 2,
                           concat_axis=2)
  else:
    # No trailing names: move the (size-1) "after" sub-axis into position 2.
    out_after = _moveaxis(split_axis + 2, 2, out_split)

  # Flatten the three concat axes and move them to the right position.
  y = out_after.reshape((np.prod(out_after.shape[:3]), *out_after.shape[3:]))
  return _moveaxis(0, concat_axis, y)
def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
  """Papply rule for `reshape`: relocate the split axis into the new shape.

  Args:
    name: name of the mapped axis (unused here).
    vals: tuple of one element, the sharded operand.
    axes: tuple of one int, the operand dimension that is split.
    new_sizes: requested (unsharded) output shape.
    dimensions: optional permutation argument of `lax.reshape`.
    old_sizes: original (unsharded) input shape.

  Returns:
    A pair `(reshaped_operand, new_axis)` where `new_axis` is the position
    of the split dimension in the new shape.

  Raises:
    NotImplementedError: if the split axis cannot be located unchanged in
      the new shape, or if `dimensions` is given.
  """
  operand, = vals
  axis, = axes

  def filter_ones(xs):
    # Materialize a list: in Python 3 the builtin filter() returns a lazy
    # iterator, and callers below need len() on the result.
    return [x for x in xs if x != 1]

  def find_new_axis(old_axis, old_sizes, new_sizes):
    # The reshape may only add/remove unit dimensions around the split axis;
    # otherwise we cannot tell where the split axis went.
    if len(filter_ones(new_sizes)) != len(filter_ones(old_sizes)):
      return None
    num_before = len(filter_ones(old_sizes[:old_axis]))
    sz = old_sizes[old_axis]
    for i, new_sz in enumerate(new_sizes):
      if num_before == 0:
        if new_sz == sz:
          return i
        elif new_sz != 1:
          return None
      elif new_sz != 1:
        num_before -= 1
    return None

  err = NotImplementedError(
      'papply of reshape that would change hidden dimension size')

  if dimensions is None:
    new_axis = find_new_axis(axis, old_sizes, new_sizes)
    if new_axis is not None:
      # Double-check that the element counts on either side of the split
      # axis are preserved; otherwise the per-shard reshape would be wrong.
      if (lax.prod(old_sizes[:axis]) != lax.prod(new_sizes[:new_axis]) or
          lax.prod(old_sizes[axis + 1:]) != lax.prod(
              new_sizes[new_axis + 1:])):
        raise err
      new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
      return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
    else:
      raise err
  else:
    raise NotImplementedError('papply of reshape with `dimensions`')