# Register VJP makers for the first two positional arguments of tensordot.
defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)

# The tensordot adjoints get VJPs of their own (presumably so the backward
# pass is itself differentiable). For each adjoint, the VJP with respect to
# its first argument is expressed through the *other* adjoint, and the VJP
# with respect to G is a plain tensordot. match_complex(ref, x) is applied
# with the corresponding input as the reference.
defvjp(tensordot_adjoint_0,
       lambda ans, B, G, axes, An, Bn: lambda A: match_complex(
           B, tensordot_adjoint_1(A, G, axes, An, Bn)),
       lambda ans, B, G, axes, An, Bn: lambda A: match_complex(
           G, anp.tensordot(A, B, axes)))
defvjp(tensordot_adjoint_1,
       lambda ans, A, G, axes, An, Bn: lambda B: match_complex(
           A, tensordot_adjoint_0(B, G, axes, An, Bn)),
       lambda ans, A, G, axes, An, Bn: lambda B: match_complex(
           G, anp.tensordot(A, B, axes)))

# np.outer flattens inputs that are not already 1-D, so ans has shape
# (a.size, b.size). Each cotangent is therefore contracted against the
# flattened partner and reshaped back to the original argument's shape.
# For 1-D inputs this equals dot(g, b) / dot(a, g) and the reshape is a no-op.
defvjp(anp.outer,
       lambda ans, a, b: lambda g: match_complex(
           a, anp.reshape(anp.dot(g, anp.ravel(b)), anp.shape(a))),
       lambda ans, a, b: lambda g: match_complex(
           b, anp.reshape(anp.dot(anp.ravel(a), g), anp.shape(b))))


def grad_concatenate_args(argnum, ans, axis_args, kwargs):
    """VJP maker for ``anp.concatenate_args`` w.r.t. the ``argnum``-th argument.

    ``axis_args`` is ``(axis, arr_0, arr_1, ...)``, so ``argnum`` counts the
    axis as position 0 and ``argnum == k`` selects ``arr_{k-1}``. The returned
    VJP slices, out of the cotangent ``g``, the block that the selected array
    occupies in ``ans`` along ``axis``. ``kwargs`` is not used here.
    """
    axis, args = axis_args[0], axis_args[1:]
    # Sizes along `axis` of every array up to and including the selected one;
    # the last entry is the selected array's own extent.
    sizes = [anp.shape(a)[axis] for a in args[:argnum]]
    start = sum(sizes[:-1])
    idxs = [slice(None)] * ans.ndim
    idxs[axis] = slice(start, start + sizes[-1])
    # NumPy requires a tuple (not a list) for multidimensional indexing.
    return lambda g: g[tuple(idxs)]


defvjp_argnum(anp.concatenate_args, grad_concatenate_args)


def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5,4)) or A.reshape(5,4).
    # The reshape function doesn't support both ways, so we have to wrap it.
    import numbers  # function-scope so this chunk needs no new module import

    # numbers.Integral also matches NumPy integer scalars (e.g. np.int64),
    # which do not subclass int on Python 3. With a plain `int` check,
    # A.reshape(np.int64(5), 4) would forward 4 as reshape's `order` argument.
    # The `args and` guard lets a call with no positional shape reach
    # anp.reshape instead of failing with IndexError here.
    if args and isinstance(args[0], numbers.Integral):
        return anp.reshape(x, args, **kwargs)
    return anp.reshape(x, *args, **kwargs)


setattr(ArrayBox, 'reshape', wrapped_reshape)


def grad_sort(ans, x, axis=-1, kind='quicksort', order=None):
    #TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError(
            "Gradient of sort not implemented for multi-dimensional arrays.")
# Register VJP makers for the first two positional arguments of tensordot.
defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)

# The tensordot adjoints get VJPs of their own (presumably so the backward
# pass is itself differentiable). For each adjoint, the VJP with respect to
# its first argument is expressed through the *other* adjoint, and the VJP
# with respect to G is a plain tensordot. match_complex(ref, x) is applied
# with the corresponding input as the reference.
defvjp(tensordot_adjoint_0,
       lambda ans, B, G, axes, An, Bn: lambda A: match_complex(
           B, tensordot_adjoint_1(A, G, axes, An, Bn)),
       lambda ans, B, G, axes, An, Bn: lambda A: match_complex(
           G, anp.tensordot(A, B, axes)))
defvjp(tensordot_adjoint_1,
       lambda ans, A, G, axes, An, Bn: lambda B: match_complex(
           A, tensordot_adjoint_0(B, G, axes, An, Bn)),
       lambda ans, A, G, axes, An, Bn: lambda B: match_complex(
           G, anp.tensordot(A, B, axes)))

# np.outer flattens inputs that are not already 1-D, so ans has shape
# (a.size, b.size). Each cotangent is therefore contracted against the
# flattened partner and reshaped back to the original argument's shape.
# For 1-D inputs this equals dot(g, b) / dot(a, g) and the reshape is a no-op.
defvjp(anp.outer,
       lambda ans, a, b: lambda g: match_complex(
           a, anp.reshape(anp.dot(g, anp.ravel(b)), anp.shape(a))),
       lambda ans, a, b: lambda g: match_complex(
           b, anp.reshape(anp.dot(anp.ravel(a), g), anp.shape(b))))


def grad_concatenate_args(argnum, ans, axis_args, kwargs):
    """VJP maker for ``anp.concatenate_args`` w.r.t. the ``argnum``-th argument.

    ``axis_args`` is ``(axis, arr_0, arr_1, ...)``, so ``argnum`` counts the
    axis as position 0 and ``argnum == k`` selects ``arr_{k-1}``. The returned
    VJP slices, out of the cotangent ``g``, the block that the selected array
    occupies in ``ans`` along ``axis``. ``kwargs`` is not used here.
    """
    axis, args = axis_args[0], axis_args[1:]
    # Sizes along `axis` of every array up to and including the selected one;
    # the last entry is the selected array's own extent.
    sizes = [anp.shape(a)[axis] for a in args[:argnum]]
    start = sum(sizes[:-1])
    idxs = [slice(None)] * ans.ndim
    idxs[axis] = slice(start, start + sizes[-1])
    # Must be a tuple: indexing with a *list* of slices is non-tuple
    # multidimensional indexing, deprecated in NumPy 1.15 and an error
    # since NumPy 1.23.
    return lambda g: g[tuple(idxs)]


defvjp_argnum(anp.concatenate_args, grad_concatenate_args)


def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5,4)) or A.reshape(5,4).
    # The reshape function doesn't support both ways, so we have to wrap it.
    import numbers  # function-scope so this chunk needs no new module import

    # numbers.Integral also matches NumPy integer scalars (e.g. np.int64),
    # which do not subclass int on Python 3. With a plain `int` check,
    # A.reshape(np.int64(5), 4) would forward 4 as reshape's `order` argument.
    # The `args and` guard lets a call with no positional shape reach
    # anp.reshape instead of failing with IndexError here.
    if args and isinstance(args[0], numbers.Integral):
        return anp.reshape(x, args, **kwargs)
    return anp.reshape(x, *args, **kwargs)


setattr(ArrayBox, 'reshape', wrapped_reshape)


def grad_sort(ans, x, axis=-1, kind='quicksort', order=None):
    #TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError(
            "Gradient of sort not implemented for multi-dimensional arrays.")