예제 #1
0
def repeat_to_match_shape(x, axis, keepdims):
    """Return a function that repeats an array along ``axis`` to get ``x.shape``.

    Also returns the number of repetitions of the array, i.e. how many
    elements of ``x`` correspond to each element of the smaller array.

    Args:
        x: The array whose shape should be reproduced. If ``isarray(x)`` is
            false, ``(I, 1)`` is returned (``I`` is presumably the identity
            function -- confirm where it is defined).
        axis: ``None`` (all axes), an ``int``, or a tuple of ints naming the
            axes to repeat along. Negative axes count from the end, as in numpy.
        keepdims: True if the incoming array still has the ``axis`` dimensions
            (with size 1); False if those dimensions were dropped.

    Returns:
        A pair ``(repeat_fn, num_reps)``.
    """
    assert isinstance(axis, (type(None), int, tuple))

    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype   # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)
        else:
            return lambda g : anp.full(shape, g, dtype=dtype), anp.prod(shape)
    elif isinstance(axis, int):
        # anp.repeat / anp.expand_dims accept negative axes directly.
        if keepdims:
            return lambda g : anp.repeat(g, shape[axis], axis), shape[axis]
        else:
            return lambda g : anp.repeat(anp.expand_dims(g, axis),
                                         shape[axis], axis), shape[axis]
    else:
        # Normalize negative axes (e.g. -1 -> ndim - 1) so the membership
        # tests below recognize them; comparing raw negative values against
        # range(len(shape)) would never match.
        ndim = len(shape)
        reduced = {ax % ndim for ax in axis}
        repeats  = [size if i in reduced else 1 for i, size in enumerate(shape)]
        expanded = [size if i not in reduced else 1 for i, size in enumerate(shape)]
        # numpy fancy indexing already handles negative indices here.
        num_reps = anp.prod(anp.array(shape)[list(axis)])

        if keepdims:
            return lambda g: anp.tile(g, repeats), num_reps
        else:
            return lambda g: anp.tile(anp.reshape(g, expanded), repeats), num_reps
 def double_val_fun(*args, **kwargs):
     """Call ``fun`` and return its result both as-is and unboxed via ``getval``."""
     result = fun(*args, **kwargs)
     unboxed = getval(result)
     return result, unboxed
 def wrapped(*args, **kwargs):
     """Forward the call to ``f`` after unboxing every argument with ``getval``."""
     plain_kwargs = dict((name, getval(value)) for name, value in kwargs.items())
     # Lazy, like the map() it replaces: positional args are unboxed when
     # the call below unpacks them.
     plain_args = (getval(arg) for arg in args)
     return f(*plain_args, **plain_kwargs)
예제 #4
0
 def sum_outgrads(outgrads):
     """Combine a list of outgoing gradients into one value.

     A single gradient is returned unchanged unless its unboxed value is a
     ``SparseArray``; every other case (several gradients, or one sparse
     gradient) is handed to ``primitive_sum_arrays``. Presumably the sparse
     case goes through the sum so it gets densified -- confirm.
     """
     # Compare the length by value: ``len(...) is 1`` tests object identity,
     # only works thanks to CPython's small-int cache, and raises a
     # SyntaxWarning on Python 3.8+.
     if len(outgrads) == 1 and not isinstance(getval(outgrads[0]), SparseArray):
         return outgrads[0]
     else:
         return primitive_sum_arrays(*outgrads)
 def zeros_like(self):
     """Return a dict with the keys of ``getval(self)``, each value passed through ``zeros_like``."""
     contents = getval(self)
     return dict((key, zeros_like(val)) for key, val in iteritems(contents))
 def zeros_like(value):
     """Return a list holding ``zeros_like`` of each element of ``getval(value)``."""
     return list(map(zeros_like, getval(value)))
예제 #7
0
def make_grad_tile(ans, x, reps):
    """Build the gradient function for ``anp.tile(x, reps)``.

    ``tile`` aligns ``reps`` with the *trailing* axes of ``x``. When ``reps``
    is shorter than ``x.ndim`` it is left-padded with 1s. When it is longer,
    ``x`` gets leading unit axes. The gradient sums the ``rep`` tiled copies
    along each output axis, then reshapes the result back to ``x.shape``.

    Args:
        ans: The output of the tile operation (not used).
        x: The array that was tiled.
        reps: An int or a sequence of ints, as passed to ``tile``.

    Returns:
        A function that maps a gradient ``g`` shaped like the tile output to
        the gradient with respect to ``x``.
    """
    reps = [reps] if anp.isscalar(reps) else list(reps)
    x_shape = x.shape
    # Left-pad with 1s so reps[i] refers to output axis i, matching tile's
    # alignment. Without this, tiling a 2-D array with reps=2 summed along
    # axis 0 instead of axis 1 and silently returned wrong values.
    reps = [1] * (len(x_shape) - len(reps)) + reps

    def tile_grad(g):
        for axis, rep in enumerate(reps):
            # Split into the `rep` copies along this axis and add them up.
            g = sum(anp.split(g, rep, axis))
        return anp.reshape(g, x_shape)
    return tile_grad
# Attach make_grad_tile as the gradient maker for anp.tile.
anp.tile.defgrad(make_grad_tile)

def make_grad_transpose(ans, x, axes=None):
    """Build the gradient function for ``anp.transpose(x, axes)``.

    The gradient transposes ``g`` by the inverse permutation
    (``argsort(axes)``). With ``axes=None`` the transpose reverses all axes,
    which is its own inverse, so ``None`` is passed straight through.

    Args:
        ans: The output of the transpose (not used).
        x: The array that was transposed (not used).
        axes: ``None`` or a permutation of the axes. Negative entries are
            allowed, as in ``numpy.transpose``.

    Returns:
        A function that maps the output gradient ``g`` to the input gradient.
    """
    if axes is not None:
        # A permutation has one entry per axis, so len(axes) == ndim. Normalize
        # negative entries first; otherwise argsort ranks e.g. -1 below 0 and
        # gives the wrong inverse permutation.
        ndim = len(axes)
        axes = anp.argsort([axis % ndim for axis in axes])
    return lambda g : anp.transpose(g, axes)
# Attach make_grad_transpose as the gradient maker for anp.transpose.
anp.transpose.defgrad(make_grad_transpose)

def isarray(x):
    """Return True if ``x``, after unboxing with ``getval``, is an ``anp.ndarray``."""
    return isinstance(getval(x), anp.ndarray)

def repeat_to_match_shape(x, axis, keepdims):
    """Returns a function that repeats an array along axis to get a given shape.
       Also returns the number of repetitions of the array."""
    assert isinstance(axis, (type(None), int, tuple))

    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype=None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype   # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)