def repeat_to_match_shape(x, axis, keepdims):
    """Return ``(repeat_fn, num_reps)`` for broadcasting ``g`` back to ``x.shape``.

    ``repeat_fn(g)`` repeats/tiles ``g`` along ``axis`` so that the result has
    ``x.shape``.  ``num_reps`` is the product of ``x.shape`` over the repeated
    axes (all axes when ``axis`` is None).

    Args:
        x: value whose shape is the target; if it is not an array
            (per ``isarray``), ``(I, 1)`` is returned instead.
        axis: None, an int, or a tuple of ints.  Negative entries count from
            the last axis, as in NumPy reductions.
        keepdims: True if ``g`` still has size-1 dimensions at ``axis``;
            False if those dimensions were dropped.
    """
    assert isinstance(axis, (type(None), int, tuple))
    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            # np.full() has a bug for complex numbers
            dtype = getval(anp.array(x)).dtype
        if keepdims:
            return lambda g: anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)
        else:
            return lambda g: anp.full(shape, g, dtype=dtype), anp.prod(shape)
    elif isinstance(axis, int):
        # A negative int axis works directly with repeat/expand_dims.
        if keepdims:
            return lambda g: anp.repeat(g, shape[axis], axis), shape[axis]
        else:
            return lambda g: anp.repeat(anp.expand_dims(g, axis), shape[axis], axis), shape[axis]
    else:
        # Normalise negative axes first: the membership tests below compare
        # against non-negative indices, so e.g. axis=(-1,) would otherwise
        # never match.  The gradient would then come back with the wrong
        # shape (keepdims) or the reshape would fail (no keepdims).
        ndim = len(shape)
        axis = tuple(a + ndim if a < 0 else a for a in axis)
        repeats = [shape[i] if i in axis else 1 for i in range(ndim)]
        expanded = [shape[i] if i not in axis else 1 for i in range(ndim)]
        num_reps = anp.prod(anp.array(shape)[list(axis)])
        if keepdims:
            return lambda g: anp.tile(g, repeats), num_reps
        else:
            return lambda g: anp.tile(anp.reshape(g, expanded), repeats), num_reps
def double_val_fun(*args, **kwargs):
    """Call ``fun`` and return ``(output, getval(output))``.

    ``fun`` is a free name resolved from an enclosing scope not shown here.
    """
    output = fun(*args, **kwargs)
    unboxed = getval(output)
    return output, unboxed
def wrapped(*args, **kwargs):
    """Forward the call to ``f``, passing every argument through ``getval`` first.

    ``f`` is a free name resolved from an enclosing scope not shown here.
    """
    plain_kwargs = {name: getval(arg) for name, arg in kwargs.items()}
    # Lazy, like the original map(): positional args are unboxed only when
    # they are unpacked into the call.
    plain_args = (getval(arg) for arg in args)
    return f(*plain_args, **plain_kwargs)
def sum_outgrads(outgrads):
    """Combine the list of gradients accumulated for one output.

    A lone gradient whose unboxed value is not a ``SparseArray`` is returned
    as-is.  Anything else (several gradients, or a single ``SparseArray``) is
    passed to ``primitive_sum_arrays``.
    """
    # ``==`` rather than ``is``: identity comparison against an int literal
    # relies on CPython's small-int cache and is a SyntaxWarning on 3.8+.
    if len(outgrads) == 1 and not isinstance(getval(outgrads[0]), SparseArray):
        return outgrads[0]
    return primitive_sum_arrays(*outgrads)
def zeros_like(self):
    """Return a dict mapping each key of ``getval(self)`` to ``zeros_like`` of its value."""
    unboxed = getval(self)
    return dict((key, zeros_like(item)) for key, item in iteritems(unboxed))
def zeros_like(value):
    """Return a list holding ``zeros_like`` of each element of ``getval(value)``."""
    return list(map(zeros_like, getval(value)))
def make_grad_tile(ans, x, reps): reps = [reps] if anp.isscalar(reps) else reps def tile_grad(g): for axis, rep in enumerate(reps): g = sum(anp.split(g, rep, axis)) return anp.reshape(g, x.shape) return tile_grad anp.tile.defgrad(make_grad_tile) def make_grad_transpose(ans, x, axes=None): if axes is not None: axes = anp.argsort(axes) return lambda g : anp.transpose(g, axes) anp.transpose.defgrad(make_grad_transpose) isarray = lambda x : isinstance(getval(x), anp.ndarray) def repeat_to_match_shape(x, axis, keepdims): """Returns a function that repeats an array along axis to get a given shape. Also returns the number of repetitions of the array.""" assert isinstance(axis, (type(None), int, tuple)) if not isarray(x): return I, 1 shape = x.shape if axis is None: dtype=None if anp.iscomplexobj(x): dtype = getval(anp.array(x)).dtype # np.full() has a bug for complex numbers if keepdims: return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)