def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose rule for ``select_and_gather_add`` w.r.t. ``tangents``.

  Only ``tangents`` may be linear (an undefined primal); ``operand`` must be
  a known value. The cotangent ``t`` is routed back to the positions chosen
  by ``select_prim`` using ``_select_and_scatter_add``.

  Args:
    t: cotangent of the output; may be an ``ad_util.Zero``.
    tangents: undefined primal standing in for the linear input.
    operand: known primal values the window selection is computed from.
    select_prim: either ``lax.ge_p`` or ``lax.le_p``.
    window_dimensions, window_strides, padding: window parameters forwarded
      unchanged to ``_select_and_scatter_add``.
    base_dilation: per-dimension dilation of ``operand``. It is emulated here
      by interior-padding ``operand`` and strided-slicing the result.
    window_dilation: must be all ones (see Raises).

  Returns:
    ``[cotangent_for_tangents, None]``. The ``None`` slot corresponds to the
    non-linear ``operand``.

  Raises:
    NotImplementedError: if any ``window_dilation`` entry is not 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents)
          and not ad.is_undefined_primal(operand))

  if any(d != 1 for d in window_dilation):
    raise NotImplementedError(
        "VJP not implemented for select_and_gather (MaxPool) with window "
        "dilation, got window_dilation={}.".format(window_dilation))

  # A symbolic-zero cotangent transposes to a symbolic-zero cotangent.
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]

  dilated = any(d != 1 for d in base_dilation)
  if dilated:
    # Materialize the base dilation by inserting (d - 1) interior elements
    # between entries. Those elements are filled with the identity of the
    # selection (max identity for ge_p, min identity otherwise), so they
    # presumably never win a window's comparison.
    if select_prim is lax.ge_p:
      identity_of = lax._get_max_identity
    else:
      identity_of = lax._get_min_identity
    fill_value = identity_of(operand.dtype)
    interior_config = tuple((0, 0, d - 1) for d in base_dilation)
    operand = lax.pad(operand, fill_value, interior_config)

  scattered = _select_and_scatter_add(t, operand, select_prim,
                                      window_dimensions, window_strides,
                                      padding)

  if dilated:
    # Keep only the positions that existed before dilation: stride by
    # base_dilation starting from the origin.
    origin = tuple(0 for _ in scattered.shape)
    scattered = slicing.slice(scattered, origin, scattered.shape,
                              base_dilation)

  return [scattered, None]
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose rule for ``select_and_scatter_add`` w.r.t. ``source``.

  Only ``source`` may be linear (an undefined primal); ``operand`` must be
  known. The cotangent ``t`` is gathered back from the positions selected by
  ``select_prim`` using ``_select_and_gather_add``.

  Args:
    t: cotangent of the output; may be an ``ad_util.Zero``.
    source: undefined primal standing in for the linear input.
    operand: known primal values the window selection is computed from.
    select_prim: selection primitive forwarded to ``_select_and_gather_add``.
    window_dimensions, window_strides, padding: window parameters forwarded
      unchanged.

  Returns:
    ``[cotangent_for_source, None]``. The ``None`` slot corresponds to the
    non-linear ``operand``.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)

  # A symbolic-zero cotangent transposes to a symbolic-zero cotangent.
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(source.aval), None]

  # One 1 per window dimension, passed for both trailing arguments.
  # NOTE(review): these are presumably base and window dilation of
  # _select_and_gather_add (i.e. "no dilation"); confirm against its signature.
  unit = (1,) * len(window_dimensions)
  gathered = _select_and_gather_add(t, operand, select_prim, window_dimensions,
                                    window_strides, padding, unit, unit)
  return [gathered, None]
def transposed(*args):
  """Transpose the closed-over ``jaxpr`` for one set of primals/cotangents.

  ``args`` is unflattened with the closed-over ``treedef`` into a pair
  ``(primals, output_cotangents)``. Known primals are folded into the jaxpr
  by partial evaluation. A backward pass is then run over the remaining
  (undefined) inputs. Known primals receive an ``ad_util.Zero`` cotangent.

  Side effect: stores the treedef of the returned cotangent structure in
  ``cell.treedef``.

  Returns:
    The flattened list of input cotangents, one pytree entry per primal.
  """
  primals, cts_out = tree_unflatten(treedef, args)

  # Undefined primals stay symbolic (unknown). All others become known
  # constants for partial evaluation.
  pvals = []
  for p in primals:
    if ad.is_undefined_primal(p):
      pvals.append(pe.PartialVal.unknown(p.aval))
    else:
      pvals.append(pe.PartialVal.known(p))

  evaluate = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
  lin_jaxpr, _, lin_consts = pe.trace_to_jaxpr_nounits(evaluate, pvals, False)

  # Every input of the partially evaluated jaxpr is treated as linear.
  placeholders = [ad.UndefinedPrimal(var.aval) for var in lin_jaxpr.invars]
  linear_cts = ad.backward_pass(lin_jaxpr, reduce_axes, False, lin_consts,
                                placeholders, cts_out)

  # Re-expand to one cotangent per primal, consuming the backward-pass
  # results in order for undefined primals and using Zero for known ones.
  remaining = iter(linear_cts)
  cts_in = []
  for p in primals:
    if ad.is_undefined_primal(p):
      cts_in.append(next(remaining))
    else:
      cts_in.append(ad_util.Zero(p.aval))
  # Every backward-pass result must have been consumed.
  assert next(remaining, None) is None

  flat_cts_in, cell.treedef = tree_flatten(cts_in)
  return flat_cts_in