Example #1
# Excerpt from autograd/numpy/numpy_vjps.py. The helpers referenced below
# (tensordot_vjp_0/1, tensordot_adjoint_0/1, match_complex) are defined
# earlier in that module; `anp` is autograd's wrapped numpy, and defvjp /
# defvjp_argnum come from autograd.extend.

# Register VJPs for tensordot, for its two adjoint primitives (needed so
# that second derivatives work), and for outer. match_complex coerces each
# cotangent to the real/complex type of the corresponding input.
defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)
defvjp(tensordot_adjoint_0, lambda ans, B, G, axes, An, Bn: lambda A: match_complex(B, tensordot_adjoint_1(A, G, axes, An, Bn)),
                            lambda ans, B, G, axes, An, Bn: lambda A: match_complex(G, anp.tensordot(A, B, axes)))
defvjp(tensordot_adjoint_1, lambda ans, A, G, axes, An, Bn: lambda B: match_complex(A, tensordot_adjoint_0(B, G, axes, An, Bn)),
                            lambda ans, A, G, axes, An, Bn: lambda B: match_complex(G, anp.tensordot(A, B, axes)))
defvjp(anp.outer, lambda ans, a, b: lambda g: match_complex(a, anp.dot(g, b.T)),
                  lambda ans, a, b: lambda g: match_complex(b, anp.dot(a.T, g)))
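
As a sanity check (not part of the original file), these registrations are
exactly what autograd's reverse mode invokes when differentiating through
tensordot. A minimal sketch using only the public autograd API; the function
and values here are mine:

import autograd.numpy as np  # public alias for the wrapped numpy (`anp` above)
from autograd import grad

def loss(A):
    B = np.ones((3, 4))
    # Contract A's last axis with B's first axis, then reduce to a scalar.
    return np.sum(np.tensordot(A, B, axes=1))

A = np.arange(6.0).reshape(2, 3)
print(grad(loss)(A))  # every entry is 4.0: the row sums of B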

def grad_concatenate_args(argnum, ans, axis_args, kwargs):
    # VJP for concatenate_args(axis, *arrays): argument 0 is the axis, so
    # argnum >= 1 selects an array. Slice out the block of the output
    # gradient that the argnum-th argument contributed along `axis`.
    axis, args = axis_args[0], axis_args[1:]
    sizes = [anp.shape(a)[axis] for a in args[:argnum]]
    start = sum(sizes[:-1])
    idxs = [slice(None)] * ans.ndim
    idxs[axis] = slice(start, start + sizes[-1])
    return lambda g: g[tuple(idxs)]
defvjp_argnum(anp.concatenate_args, grad_concatenate_args)
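
Again as an illustration rather than part of the source: each argument to
concatenate receives the matching slice of the upstream gradient, chained
through whatever produced it (public API; example values are mine):

import autograd.numpy as np
from autograd import grad

def f(x, y):
    return np.sum(np.concatenate((x, 2.0 * y)) ** 2)

x, y = np.array([1.0, 2.0]), np.array([3.0])
print(grad(f, 0)(x, y))  # [2. 4.] -- the slice of g belonging to x, i.e. 2*x
print(grad(f, 1)(x, y))  # [24.]   -- the slice for 2*y, chained: 8*y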

def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5,4)) or A.reshape(5,4).
    # The reshape function doesn't support both ways, so we have to wrap it.
    if isinstance(args[0], int):
        return anp.reshape(x, args, **kwargs)
    else:
        return anp.reshape(x, *args, **kwargs)
setattr(ArrayBox, 'reshape', wrapped_reshape)
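
The wrapper matters because, inside a function being differentiated, `x` is
an ArrayBox rather than a raw ndarray, so both method-call spellings must
route to anp.reshape. A small illustration (public API; the function is mine):

import autograd.numpy as np
from autograd import grad

def f(x):
    # Both spellings hit wrapped_reshape on the ArrayBox `x`.
    return np.sum(x.reshape(2, 3) * x.reshape((2, 3)))

print(grad(f)(np.arange(6.0)))  # 2*x: [ 0.  2.  4.  6.  8. 10.]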

def grad_sort(ans, x, axis=-1, kind='quicksort', order=None):
    #TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError(
            "Gradient of sort not implemented for multi-dimensional arrays.")
    # The listing is truncated here; upstream autograd completes the VJP by
    # recording the sort permutation and undoing it on the gradient:
    sort_perm = anp.argsort(x, axis, kind, order)
    return lambda g: unpermuter(g, sort_perm)
defvjp(anp.sort, grad_sort)