def test_gpu_node_sqrt(a):
    """Check that ``sum(sqrt(x))`` and its gradient agree between GPU and CPU.

    The graph is first built and differentiated with CUDA enabled, then
    rebuilt with CUDA disabled. Both the forward values (``g3`` / ``c3``) and
    the gradients with respect to ``g1`` are compared with ``close``.

    Args:
        a: Input value wrapped in a ``Variable`` (supplied by the test
            parametrization outside this view).
    """
    set_cuda_active(True)
    try:
        g1 = Variable(a)

        g3 = sum(sqrt(g1))
        g = g3.grad()
        g_g1 = g.get(g1)
        g3.to_cpu()
    finally:
        # Always restore CPU mode. Otherwise a failure in the GPU phase
        # would leave CUDA enabled for every test that runs afterwards.
        set_cuda_active(False)

    c3 = sum(sqrt(g1))
    c = c3.grad()
    c_g1 = c.get(g1)

    close(g3, c3)
    close(c_g1, g_g1)
def _get_cpu(self, dy, node):
    """Return the centered-variance RMSprop-style step for gradient ``dy``.

    Two exponential moving averages are kept per node in
    ``self._params[id(node)]``:

    - ``'pmse'``: average of ``dy**2`` with decay ``self._g``.
    - ``'pra'``: average of ``dy`` with decay ``self._ra``.

    Their difference ``v = pmse - pra**2`` is clipped at zero and used in the
    denominator ``sqrt(v + self._epsilon)``.

    Returns:
        ``self._lr * dy / sqrt(v + self._epsilon)``. If the result is a
        ``Node``, its ``detach_graph()`` is called before it is returned.
    """
    key = id(node)
    state = self._params.get(key)
    if state is None:
        prev_sq, prev_avg = 0, 0
    else:
        prev_sq, prev_avg = state['pmse'], state['pra']

    decay_sq, decay_avg = self._g, self._ra
    mean_sq = decay_sq * prev_sq + (1 - decay_sq) * (dy ** 2)
    mean_grad = decay_avg * prev_avg + (1 - decay_avg) * dy

    # The difference can be negative (e.g. when the two decay rates differ),
    # so clip it in place. Unwrap to a plain array first when the value
    # offers ``as_ndarray``.
    variance = mean_sq - mean_grad ** 2
    if hasattr(variance, "as_ndarray"):
        variance = variance.as_ndarray()
    variance[variance < 0] = 0

    step = self._lr * dy / sqrt(variance + self._epsilon)
    self._params[key] = {'pmse': mean_sq, 'pra': mean_grad}
    if isinstance(step, Node):
        step.detach_graph()
    return step
def __call__(self, dy, node):
    """Return the Adam update for gradient ``dy`` belonging to ``node``.

    Per-node state lives in ``self._params[id(node)]``:

    - ``"u"``: moving average of ``dy`` with decay ``self._b``.
    - ``"r"``: moving average of ``dy**2`` with decay ``self._g``.
    - ``"beta"`` / ``"ganma"``: running powers of ``self._b`` / ``self._g``,
      used for bias correction.

    On the CPU path, entries of the stored ``u`` and ``r`` whose magnitude is
    below ``self._min`` are flushed to zero before the update. Each moment is
    masked by its own magnitude.

    Args:
        dy: Gradient for ``node``.
        node: Object whose ``id`` keys the optimizer state.

    Returns:
        ``self._lr * u / (sqrt(r / (1 - g)) + self._epsilon) / (1 - b)``.
        If the result is a ``Node``, its ``detach_graph()`` is called first.
    """
    node_id = id(node)
    pdy = self._params.get(node_id)
    if pdy is None:
        b = self._b
        g = self._g
        u = (1 - self._b) * dy
        r = (1 - self._g) * (dy**2)
    else:
        u = pdy["u"]
        r = pdy["r"]
        b = pdy["beta"]
        g = pdy["ganma"]
        if not is_cuda_active():
            # Writable flags are only needed for the in-place flush below.
            u.setflags(write=True)
            r.setflags(write=True)
            # Use a separate mask per moment. Masking ``u`` by ``|r|``
            # would zero the wrong entries and leave tiny ``u`` values in
            # place.
            u[np.abs(u) < self._min] = 0
            r[np.abs(r) < self._min] = 0
        u = self._b * u + (1 - self._b) * dy
        r = self._g * r + (1 - self._g) * (dy**2)
    self._params[node_id] = {
        "beta": b * self._b,
        "ganma": g * self._g,
        "u": u,
        "r": r,
    }
    # (1 - b) and (1 - g) are the Adam bias corrections for this step.
    ret = self._lr * u / (sqrt(r / (1 - g)) + self._epsilon) / (1 - b)
    if isinstance(ret, Node):
        ret.detach_graph()
    return ret
def __call__(self, dy, node):
    """Return the RMSprop step for gradient ``dy`` of ``node``.

    The running mean of ``dy**2`` (decay ``self._g``) is stored directly in
    ``self._params[id(node)]``. Before the first call it is treated as 0.

    Returns:
        ``self._lr * dy / (sqrt(mean_sq) + self._epsilon)``. If the result is
        a ``Node``, its ``detach_graph()`` is called before it is returned.
    """
    key = id(node)
    decay = self._g
    mean_sq = decay * self._params.get(key, 0) + (1 - decay) * (dy ** 2)
    self._params[key] = mean_sq

    step = self._lr * dy / (sqrt(mean_sq) + self._epsilon)
    if isinstance(step, Node):
        step.detach_graph()
    return step
def _get_cpu(self, dy, node):
    """Return the RMSprop step for gradient ``dy`` of ``node`` (CPU path).

    The running mean of ``dy**2`` (decay ``self._g``) is kept under key
    ``'pmse'`` in ``self._params[id(node)]``. Before the first call it is
    treated as 0.

    Returns:
        ``self._lr * dy / (sqrt(mean_sq) + self._epsilon)``. If the result is
        a ``Node``, its ``detach_graph()`` is called before it is returned.
    """
    key = id(node)
    state = self._params.get(key)
    prev_mse = 0 if state is None else state['pmse']

    mean_sq = self._g * prev_mse + (1 - self._g) * (dy ** 2)
    step = self._lr * dy / (sqrt(mean_sq) + self._epsilon)
    self._params[key] = {'pmse': mean_sq}

    if isinstance(step, Node):
        step.detach_graph()
    return step
def normalized_form(x):
    """Return ``sqrt(sum(x**2))`` computed with ``op``, keeping reduced dims.

    Despite the name, this returns the Euclidean (L2) norm of ``x``, not a
    normalized copy of ``x``. The result keeps the reduced dimensions. This
    assumes ``op.sum`` with no axis reduces over every element — confirm
    against the ``op`` backend.
    """
    squared = op.square(x)
    total = op.sum(squared, keepdims=True)
    return op.sqrt(total)