def instantiate_const_abstracted(self, tracer):
    """Return a tracer whose partial value is abstract (unknown).

    Args:
      tracer: a tracer exposing ``pval`` (unpacked here as a two-element
        ``(pv, const)`` pair) and an ``is_key()`` method.

    Returns:
      ``tracer`` itself when the first element of its ``pval`` is already a
      ``jax_core.AbstractValue``; otherwise a new ``UnzipTracer`` bound to
      this trace that wraps the constant in a ``pe.ConstVar`` and carries an
      unknown partial value built from the constant's shaped aval.

    Raises:
      TypeError: if the first element of ``pval`` is neither an
        ``AbstractValue`` nor ``None``.
    """
    abstract_part, known_value = tracer.pval

    # Already abstract: nothing to instantiate.
    if isinstance(abstract_part, jax_core.AbstractValue):
        return tracer

    # Anything other than None at this point is an unexpected pval layout.
    if abstract_part is not None:
        raise TypeError(abstract_part)

    # The scalar check is forwarded as the second positional argument of
    # raise_to_shaped (presumably its weak-type flag -- confirm against the
    # JAX version in use).
    shaped_aval = abstract_arrays.raise_to_shaped(
        trace_util.get_shaped_aval(known_value), onp.isscalar(known_value))
    unknown_pval = pe.PartialVal.unknown(shaped_aval)
    recipe = pe.ConstVar(known_value)
    # The new tracer inherits whatever is_key() reports for the input tracer.
    return UnzipTracer(self, unknown_pval, recipe, tracer.is_key())
def new_instantiated_const(self, val):
    """Wrap ``val`` in a new ``UnzipTracer`` with an unknown partial value.

    Args:
      val: the constant to wrap; its shaped aval is obtained through
        ``trace_util.get_shaped_aval``.

    Returns:
      An ``UnzipTracer`` bound to this trace, with ``val`` stored as a
      ``pe.ConstVar`` and ``True`` passed as the final constructor argument
      (presumably the "is key" flag -- confirm against ``UnzipTracer``).
    """
    unknown_pval = pe.PartialVal.unknown(trace_util.get_shaped_aval(val))
    recipe = pe.ConstVar(val)
    return UnzipTracer(self, unknown_pval, recipe, True)