Code example #1
import numpy as np
import jax.numpy as jnp
from jax import lax

# unzip3, prod, Ones, is_ones and laxref are project-local helpers (not shown).
def pad_dependency_rule(outstart, outcount, operand, padding_value,
                        padding_config):
    # Invert lax.pad: find the input slice that feeds the requested
    # output region [outstart, outstart + outcount.shape).
    lo, _, interior = unzip3(padding_config)
    dilation = np.array(interior) + 1
    outstart_lo = np.subtract(outstart, lo)
    inclip = lambda indices: np.clip(indices, 0, operand.shape)
    # _ceil_divide is a private JAX helper (elementwise ceiling division).
    instart = inclip(lax.lax._ceil_divide(outstart_lo, dilation))
    instop = inclip(
        lax.lax._ceil_divide(outstart_lo + outcount.shape, dilation))
    inshape = instop - instart
    insize = prod(inshape)
    # Offset of the first contributing input element within the output region.
    offset = instart * dilation - outstart_lo
    limit = offset + np.maximum(0, (np.array(inshape) - 1) * dilation + 1)
    incount = Ones(inshape) if is_ones(outcount) else laxref.slice(
        outcount, offset, limit, dilation) if insize else None
    # The remaining output elements come from padding_value, not the operand.
    padcount = outcount.size - insize

    def outslice(inslice, padding_value):
        # Re-pad the computed input slice to reproduce the output region.
        assert inslice is None or np.array_equal(inslice.shape, inshape)
        return (lax.pad(
            inslice, padding_value,
            zip(offset,
                np.array(outcount.shape) - limit, interior)) if insize else
                jnp.full(outcount.shape, padding_value, operand.dtype))

    return ([(instart, inshape) if insize else None,
             ([], [])], [incount, padcount], outslice)
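For reference, lax.pad applies (low, high, interior) padding per dimension; the dependency rule above inverts that mapping, so a minimal forward example may help (uses only the public jax.lax.pad API):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1., 4.)            # [1., 2., 3.]
# One (low, high, interior) triple per dimension.
lax.pad(x, 0., [(1, 2, 1)])
# low=1 zero in front, interior=1 zero between elements, high=2 at the end:
# [0., 1., 0., 2., 0., 3., 0., 0.]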
Code example #2
File: jax2tf.py  Project: hereismari/jax
# In jax2tf, tf is TensorFlow and tfxla the tf2xla ops module;
# util and _pad_shape are module-local helpers (not shown).
def _pad(operand, padding_value, padding_config):
  low, high, interior = util.unzip3(padding_config)
  # Fast path: plain non-negative edge padding maps directly onto tf.pad.
  if all(lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config):
    return tf.pad(operand, util.safe_zip(low, high),
                  mode="CONSTANT", constant_values=padding_value)
  # General case (negative or interior padding): lower to XlaPad.
  # TODO(necula): implement shape inference for XlaPad
  out_shape = _pad_shape(operand, padding_value, padding_config)
  out = tfxla.pad(operand, padding_value, low, high, interior)
  out.set_shape(out_shape)
  return out
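The fast path works because tf.pad in CONSTANT mode agrees with lax.pad whenever all edge paddings are non-negative and interior padding is zero. A quick sketch (assuming TensorFlow is installed):

import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
# paddings lists (before, after) per dimension, i.e. (low, high).
tf.pad(x, [[1, 0], [0, 2]], mode="CONSTANT", constant_values=7)
# [[7, 7, 7, 7],
#  [1, 2, 7, 7],
#  [3, 4, 7, 7]]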
Code example #3
# Uses a pre-omnistaging JAX API: jc is an old jax.core (unitvar, freevars,
# restructure and pat_fmap no longer exist in current releases), and map,
# unzip3, pack and get_primitive_init are assumed to be project helpers.
# In particular, `map` must be an eager safe_map: with Python 3's lazy
# builtin map, the final `map(write, ...)` below would never run.
def init_interpreter(rng, jaxpr, consts, freevar_vals, net_params, *args):
    # Standard jaxpr-interpreter environment: map each variable to its value.
    def read(v):
        if type(v) is jc.Literal:
            return v.val
        else:
            return env[v]

    def write(v, val):
        env[v] = val

    env = {}
    write(jc.unitvar, jc.unit)
    jc.pat_fmap(write, jaxpr.constvars, consts)
    jc.pat_fmap(write, jaxpr.invars, args)
    jc.pat_fmap(write, jaxpr.freevars, freevar_vals)
    for eqn in jaxpr.eqns:
        # Give each primitive its own RNG so parameter init is deterministic.
        rng, prim_rng = random.split(rng)
        if not eqn.restructure:
            in_vals = map(read, eqn.invars)
        else:
            in_vals = [
                pack(map(read, invars))
                if type(invars) is tuple else read(invars)
                for invars in eqn.invars
            ]
        if eqn.bound_subjaxprs:
            # Higher-order primitives carry sub-jaxprs; resolve their
            # constants and free variables before initializing them.
            subjaxprs, sub_consts, sub_freevar_vals = unzip3([
                (subjaxpr, map(read, const_vars), map(read, bound_vars))
                for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs
            ])
            ans, net_params = get_primitive_init(
                eqn.primitive)(prim_rng, eqn.params, sub_consts,
                               sub_freevar_vals, in_vals, net_params)
        else:
            ans, net_params = get_primitive_init(eqn.primitive)(prim_rng,
                                                                net_params,
                                                                *in_vals,
                                                                **eqn.params)
        outvals = list(ans) if eqn.destructure else [ans]
        map(write, eqn.outvars, outvals)
    return net_params
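The read/write environment above is the standard shape of a jaxpr interpreter. A minimal sketch against the current public API, following the pattern in JAX's custom-interpreter docs (handles only first-order primitives; attribute names can shift between JAX versions):

import jax
from jax import core
import jax.numpy as jnp

def eval_jaxpr(jaxpr, consts, *args):
    env = {}
    read = lambda v: v.val if isinstance(v, core.Literal) else env[v]
    for v, val in zip(jaxpr.constvars, consts):
        env[v] = val
    for v, val in zip(jaxpr.invars, args):
        env[v] = val
    for eqn in jaxpr.eqns:
        # bind() applies the primitive to concrete values.
        outs = eqn.primitive.bind(*[read(v) for v in eqn.invars], **eqn.params)
        outs = outs if eqn.primitive.multiple_results else [outs]
        for v, val in zip(eqn.outvars, outs):
            env[v] = val
    return [read(v) for v in jaxpr.outvars]

closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.)(1.0)
eval_jaxpr(closed.jaxpr, closed.literals, 1.0)   # [Array(1.6829...)]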
Code example #4
File: jax_to_tf.py  Project: qiuminxu/jax
# A variant of the jax2tf rule in example #2: it always lowers to XlaPad,
# with no tf.pad fast path for plain edge padding.
def _pad(operand, padding_value, padding_config):
    low, high, interior = util.unzip3(padding_config)
    out_shape = _pad_shape(operand, padding_value, padding_config)
    out = tfxla.pad(operand, padding_value, low, high, interior)
    out.set_shape(out_shape)
    return out
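Both converters delegate the output shape to _pad_shape. Under XLA's pad semantics each dimension of size d becomes low + high + d + max(d - 1, 0) * interior, so a hypothetical standalone equivalent (the real _pad_shape is internal to these projects and not shown) could look like:

def pad_output_shape(in_shape, padding_config):
    # One (low, high, interior) triple per dimension, as in lax.pad.
    return tuple(lo + hi + d + max(d - 1, 0) * interior
                 for d, (lo, hi, interior) in zip(in_shape, padding_config))

pad_output_shape((3,), [(1, 2, 1)])   # (8,) -- matches the lax.pad example above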