def def_grads(reg, prims):
    def identity(x):
        return x
    # dot
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T))
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(a.T, g), argnum=1)
    # non-linear
    #prims.tanh.def_grad(lambda ans, x: lambda g: g / np.cosh(x) ** 2)
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(lambda ans, x: lambda g: ndarray.full(x.shape, g, x.context))
    # + - * /
    prims('multiply').def_grad(lambda ans, x, y: unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(lambda ans, x, y: unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: unbroadcast(ans, x, identity))
    prims('add').def_grad(lambda ans, x, y: unbroadcast(ans, y, identity), argnum=1)
    prims('subtract').def_grad(lambda ans, x, y: unbroadcast(ans, x, identity))
    prims('subtract').def_grad(lambda ans, x, y: unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(lambda ans, x, y: unbroadcast(ans, x, lambda g: g / y))
    prims('divide').def_grad(lambda ans, x, y: unbroadcast(ans, y, lambda g: - g * x / (y * y)), argnum=1)
    prims('true_divide').def_grad(lambda ans, x, y: unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(lambda ans, x, y: unbroadcast(ans, y, lambda g: - g * x / (y * y)), argnum=1)
    # power
    #prims.power.def_grad(lambda ans, x, y: unbroadcast(ans, x, lambda g: g * y * x ** (y - 1)))
    #prims.power.def_grad(lambda ans, x, y: unbroadcast(ans, y, lambda g: g * ndarray.log(x) * x ** y), argnum=1)
    # mod
    #prims.mod.def_grad(lambda ans, x, y: unbroadcast(ans, x, identity))
    #prims.mod.def_grad(lambda ans, x, y: unbroadcast(ans, y, lambda g: - g * ndarray.floor(x / y)), argnum=1)
    # negate
    prims('negative').def_grad(lambda ans, x: operator.neg)
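# The gradient rules above call an `unbroadcast` helper that is defined
# elsewhere in the module. The sketch below is an illustrative assumption
# (not the module's actual implementation), assuming `mxnet.ndarray` is
# bound as `ndarray`: it wraps a gradient function so that the gradient it
# returns is summed back down to the shape of the broadcast operand.
def unbroadcast(ans, target, gradfun):
    def new_gradfun(g):
        result = gradfun(g)
        # Collapse leading axes introduced by broadcasting.
        while len(result.shape) > len(target.shape):
            result = ndarray.sum(result, axis=0)
        # Sum over axes where the operand had size 1 but the gradient does not.
        for axis, size in enumerate(target.shape):
            if size == 1 and result.shape[axis] != 1:
                result = ndarray.sum(result, axis=axis, keepdims=True)
        return result
    return new_gradfun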
def gen_sum_grad(ans, x, axis, keepdims):
    xshape = list(x.shape)
    if axis is None:
        return lambda g: ndarray.full(x.shape, g, x.context)
    if type(axis) is int:
        axis = [axis]
    elif type(axis) is tuple:
        axis = list(axis)
    for a in axis:
        xshape[a] = 1
    def sum_grad(g):
        return ndarray.zeros(x.shape, ctx=g.context) + g.reshape(tuple(xshape))
    return sum_grad
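# Illustrative usage of gen_sum_grad for a reduction over axis 1. The data
# below is hypothetical and assumes `mxnet.ndarray` is imported as `ndarray`,
# matching the calls inside gen_sum_grad.
def _example_gen_sum_grad():
    x = ndarray.ones((2, 3))
    ans = ndarray.sum(x, axis=1)                              # shape (2,)
    sum_grad = gen_sum_grad(ans, x, axis=1, keepdims=False)
    g = ndarray.ones((2,))                                    # upstream gradient, same shape as ans
    dx = sum_grad(g)                                          # gradient broadcast back to x.shape == (2, 3)
    return dx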
def _viterbi_decode(self, feats):
    backpointers = []

    # Initialize the viterbi variables in log space
    vvars = nd.full((1, self.tagset_size), -10000.)
    vvars[0, self.tag2idx[START_TAG]] = 0

    for feat in feats:
        bptrs_t = []  # holds the backpointers for this step
        viterbivars_t = []  # holds the viterbi variables for this step

        for next_tag in range(self.tagset_size):
            # next_tag_var[i] holds the viterbi variable for tag i at the
            # previous step, plus the score of transitioning
            # from tag i to next_tag.
            # We don't include the emission scores here because the max
            # does not depend on them (we add them in below).
            next_tag_var = vvars + self.transitions[next_tag]
            best_tag_id = argmax(next_tag_var)
            bptrs_t.append(best_tag_id)
            viterbivars_t.append(next_tag_var[0, best_tag_id])
        # Now add in the emission scores, and assign vvars to the set
        # of viterbi variables we just computed.
        vvars = (nd.concat(*viterbivars_t, dim=0) + feat).reshape((1, -1))
        backpointers.append(bptrs_t)

    # Transition to STOP_TAG
    terminal_var = vvars + self.transitions[self.tag2idx[STOP_TAG]]
    best_tag_id = argmax(terminal_var)
    path_score = terminal_var[0, best_tag_id]

    # Follow the back pointers to decode the best path.
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = bptrs_t[best_tag_id]
        best_path.append(best_tag_id)
    # Pop off the start tag (we don't want to return that to the caller).
    start = best_path.pop()
    assert start == self.tag2idx[START_TAG]  # Sanity check
    best_path.reverse()
    return path_score, best_path
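# _viterbi_decode relies on an `argmax` helper, as in the BiLSTM-CRF tutorial
# this code follows, that returns the index of the largest entry of a
# 1 x tagset_size NDArray as a plain Python int so it can index backpointer
# lists and transition rows. A plausible sketch, assuming `mxnet.ndarray`
# is imported as `nd` (this is an assumption, not the original helper):
def argmax(vec):
    # Best tag id along the tag dimension, converted to a Python int.
    return int(nd.argmax(vec, axis=1).asscalar())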
def grad(g):
    if isinstance(g, (int, float)):
        return ndarray.full(x.shape, g, x.context)
    else:
        return ndarray.full(x.shape, g.asscalar(), x.context)
def full_1d(length, fill_value, dtype, ctx):
    return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
def __call__(self, p: float, p_hat: NDArray, y: NDArray, y_hat: NDArray) -> NDArray:
    return (self._alpha * mean(self._bce(p_hat, full(p_hat.shape, p)))
            + self._beta * mean(self._l1(y_hat, y)))
def _update_impl(self, indices, weight, grad, state, multi_precision=False):
    """update function"""
    try:
        from mxnet.ndarray.contrib import adamw_update
    except ImportError:
        raise ImportError(
            'Failed to import nd.contrib.adamw_update from MXNet. '
            'BERTAdam optimizer requires mxnet>=1.5.0b20190220. '
            'Please upgrade your MXNet version. For example: '
            'pip uninstall mxnet-cu90 -y; pip install mxnet-cu90 --pre')
    self._update_count(indices)
    lr = self._get_lr(indices)
    wd = self._get_wd(indices)
    # pylint: disable=access-member-before-definition
    if not isinstance(self.rescale_grad, NDArray):
        self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weight.context)
    else:
        self.rescale_grad = self.rescale_grad.as_in_context(weight.context)
    kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
              'rescale_grad': self.rescale_grad}
    if self.clip_gradient:
        kwargs['clip_gradient'] = self.clip_gradient
    if not multi_precision:
        mean, var = state
        adamw_update(weight, grad, mean, var, out=weight,
                     lr=1, wd=wd, eta=lr, **kwargs)
    else:
        try:
            from mxnet.ndarray.contrib import mp_adamw_update
        except ImportError:
            raise ImportError(
                'Failed to import nd.contrib.mp_adamw_update from MXNet. '
                'BERTAdam optimizer requires mxnet>=1.5.0b20190220. '
                'Please upgrade your MXNet version. For example: '
                'pip uninstall mxnet-cu90 -y; pip install mxnet-cu90 --pre')
        mean, var = state[0]
        mp_adamw_update(weight, grad, mean, var, state[1], out=weight,
                        lr=1, wd=wd, eta=lr, **kwargs)
def __getitem__(self, idx):
    return (nd.array(np.random.rand(1, 32)),
            nd.full(1, random.randint(0, 10), dtype="float32"))
def full(shape, fill_value, dtype, ctx):
    return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)
def nd_full(*args, **kwargs):
    """ MRT wrapper method for mxnet.NDArray.full. """
    return nd.full(*args, dtype="float64", **kwargs)
def nd_full(*args, **kwargs):
    return nd.full(*args, dtype="float64", **kwargs)
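# Hypothetical usage of the nd_full wrapper above: shape and fill value are
# forwarded unchanged while the dtype is pinned to float64.
def _example_nd_full():
    return nd_full(shape=(2, 2), val=0.5)  # float64 NDArray filled with 0.5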