def pad_dependency_rule(outstart, outcount, operand, padding_value,
                        padding_config):
  """Dependency rule for `lax.pad`: map an output window back to its inputs.

  Given a window of the padded output that starts at index `outstart` and has
  the shape of `outcount`, work out which slice of `operand` lands inside it
  and how many window elements come from padding instead.

  Args:
    outstart: start index of the output window, one entry per dimension.
    outcount: array whose shape is the output window's shape. Either it is
      detected as all-ones via `is_ones`, or its entries at the positions where
      operand elements land are sliced out to form `incount`.
    operand: the array being padded; only its `shape` and `dtype` are read.
    padding_value: not read here (`outslice` takes its own `padding_value`).
    padding_config: per-dimension (low, high, interior) triples; the `high`
      component is not read.

  Returns:
    A triple:
      - `[(instart, inshape) or None, ([], [])]`: the operand window (None when
        it is empty), and an empty start/shape pair for the padding value,
        presumably because it is a scalar (TODO confirm).
      - `[incount, padcount]`: counts for the operand window, and the number
        of output-window elements not covered by an operand element.
      - `outslice(inslice, padding_value)`: rebuilds the output window from an
        operand slice of shape `inshape`.
  """
  lo, _, interior = unzip3(padding_config)
  # Consecutive operand elements are `interior + 1` output positions apart.
  dilation = np.array(interior) + 1
  # Window start measured from where operand element 0 lands in the output.
  outstart_lo = np.subtract(outstart, lo)
  inclip = lambda indices: np.clip(indices, 0, operand.shape)
  # `_ceil_divide` presumably rounds up (per its name). If so, `instart` is the
  # first operand index landing at/after the window start, and `instop` is one
  # past the last index landing before the window end. Both are clipped to
  # the operand's bounds.
  instart = inclip(lax.lax._ceil_divide(outstart_lo, dilation))
  instop = inclip(
      lax.lax._ceil_divide(outstart_lo + outcount.shape, dilation))
  inshape = instop - instart
  insize = prod(inshape)
  # Position, inside the output window, of the first contributing operand
  # element, and the end of the dilated operand span.
  offset = instart * dilation - outstart_lo
  limit = offset + np.maximum(0, (np.array(inshape) - 1) * dilation + 1)
  # Parses as:
  #   Ones(inshape) if is_ones(outcount) else (slice(...) if insize else None)
  # NOTE(review): an all-ones `outcount` with an empty operand window yields
  # `Ones(inshape)` rather than None — confirm that is intended.
  incount = Ones(inshape) if is_ones(outcount) else laxref.slice(
      outcount, offset, limit, dilation) if insize else None
  # Window elements not occupied by an operand element come from padding.
  padcount = outcount.size - insize

  def outslice(inslice, padding_value):
    # Re-pad the operand slice so it fills the output window. With no operand
    # elements in the window, the result is entirely `padding_value`.
    assert inslice is None or np.array_equal(inslice.shape, inshape)
    return (lax.pad(
        inslice, padding_value,
        zip(offset, np.array(outcount.shape) - limit, interior))
            if insize else jnp.full(outcount.shape, padding_value, operand.dtype))

  return ([(instart, inshape) if insize else None, ([], [])],
          [incount, padcount], outslice)
def _pad(operand, padding_value, padding_config):
  """Translate a JAX `pad` into TensorFlow ops.

  Padding that only adds non-negative amounts at the edges, with no interior
  padding, is emitted as a plain `tf.pad`. Anything else (negative edge
  padding or interior padding) goes through `tfxla.pad`, and the result's
  static shape is then set from `_pad_shape`.

  Args:
    operand: tensor to pad.
    padding_value: value for the new elements; on the fast path it is passed
      to `tf.pad` as `constant_values`.
    padding_config: (low, high, interior) triples; the low/high pairs become
      `tf.pad`'s per-dimension paddings.

  Returns:
    The padded tensor.
  """
  pads_lo, pads_hi, pads_interior = util.unzip3(padding_config)
  edges_only = all(lo >= 0 and hi >= 0 and i == 0
                   for lo, hi, i in padding_config)
  if edges_only:
    # Plain constant edge padding: no XLA-specific op needed.
    return tf.pad(operand, util.safe_zip(pads_lo, pads_hi),
                  mode="CONSTANT", constant_values=padding_value)
  # TODO(necula): implement shape inference for XlaPad
  # Until then, compute the static output shape ourselves and attach it.
  static_shape = _pad_shape(operand, padding_value, padding_config)
  padded = tfxla.pad(operand, padding_value, pads_lo, pads_hi, pads_interior)
  padded.set_shape(static_shape)
  return padded
def init_interpreter(rng, jaxpr, consts, freevar_vals, net_params, *args):
  """Walk `jaxpr` equation by equation, calling each primitive's init rule.

  For every equation, the init rule from `get_primitive_init(eqn.primitive)`
  is called on the equation's input values. The `net_params` it returns is
  threaded into the next call, and its result is bound to the equation's
  outvars.

  Args:
    rng: random key; split once per equation with `random.split`. The second
      half goes to the init rule and the first half is carried forward.
    jaxpr: the jaxpr to walk (its constvars, invars, freevars and eqns are
      read).
    consts: values bound to `jaxpr.constvars` via `jc.pat_fmap`.
    freevar_vals: values bound to `jaxpr.freevars` via `jc.pat_fmap`.
    net_params: state threaded through every init rule. Presumably the
      network-parameter collection being built; confirm against callers.
    *args: values bound to `jaxpr.invars` via `jc.pat_fmap`.

  Returns:
    The `net_params` returned by the last init rule, or the input
    `net_params` unchanged if `jaxpr` has no equations.
  """
  env = {}

  def read(v):
    # Literals carry their value inline; every other var is looked up in env.
    if type(v) is jc.Literal:
      return v.val
    return env[v]

  def write(v, val):
    env[v] = val

  write(jc.unitvar, jc.unit)
  jc.pat_fmap(write, jaxpr.constvars, consts)
  jc.pat_fmap(write, jaxpr.invars, args)
  jc.pat_fmap(write, jaxpr.freevars, freevar_vals)
  for eqn in jaxpr.eqns:
    rng, prim_rng = random.split(rng)
    if not eqn.restructure:
      in_vals = [read(v) for v in eqn.invars]
    else:
      # Tuple-grouped invars are read and packed into a single value.
      in_vals = [
          pack([read(v) for v in invars]) if type(invars) is tuple
          else read(invars)
          for invars in eqn.invars
      ]
    if eqn.bound_subjaxprs:
      # The sub-jaxprs themselves are not forwarded to the init rule, only
      # their bound const and freevar values.
      _, sub_consts, sub_freevar_vals = unzip3([
          (subjaxpr, [read(v) for v in const_vars],
           [read(v) for v in bound_vars])
          for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs
      ])
      ans, net_params = get_primitive_init(eqn.primitive)(
          prim_rng, eqn.params, sub_consts, sub_freevar_vals, in_vals,
          net_params)
    else:
      ans, net_params = get_primitive_init(eqn.primitive)(
          prim_rng, net_params, *in_vals, **eqn.params)
    outvals = list(ans) if eqn.destructure else [ans]
    # Bind outputs with an explicit loop. Calling `map` only for its side
    # effect does nothing when `map` is Python 3's lazy builtin. The length
    # check keeps a mismatch from silently dropping outputs.
    assert len(eqn.outvars) == len(outvals), (
        'length mismatch: {}'.format([len(eqn.outvars), len(outvals)]))
    for var, val in zip(eqn.outvars, outvals):
      write(var, val)
  return net_params
def _pad(operand, padding_value, padding_config):
  """Translate a JAX `pad` into an XLA pad via `tfxla.pad`.

  Every padding configuration (including negative edge padding and interior
  padding) is handed to `tfxla.pad`. The static result shape computed by
  `_pad_shape` is then attached to the output.

  NOTE(review): if this module also defines the `_pad` variant that
  short-circuits to `tf.pad` for non-negative, interior-free padding, the
  later definition silently replaces the earlier one. Confirm which is
  intended.

  Args:
    operand: tensor to pad.
    padding_value: value for the new elements.
    padding_config: (low, high, interior) triples.

  Returns:
    The padded tensor with its static shape set.
  """
  # (lows, highs, interiors), forwarded positionally to tfxla.pad below.
  lo_hi_interior = util.unzip3(padding_config)
  static_shape = _pad_shape(operand, padding_value, padding_config)
  padded = tfxla.pad(operand, padding_value, *lo_hi_interior)
  # Presumably tfxla.pad does not infer its output shape (TODO confirm), so
  # the shape from _pad_shape is attached explicitly.
  padded.set_shape(static_shape)
  return padded