Example #1
0
def def_grads(prims):
    """Register gradient generators for the primitives exposed by ``prims``.

    Every row of the table below is applied in order: ``prims(name)`` is
    looked up, then either ``def_grad_zero()`` is called (generator is
    ``None``) or the generator is passed to ``def_grad`` — with
    ``argnum=...`` only for operands other than the first.

    A generator receives the forward result ``ans`` followed by the
    primitive's inputs and returns a function mapping the upstream
    gradient ``g`` to the gradient of the selected operand.
    """
    def identity(g):
        return g

    registrations = (
        # (primitive, gradient generator or None for zero gradient, argnum)
        # dot
        ('dot', lambda ans, a, b: lambda g: ndarray.dot(g, b.T), 0),
        ('dot', lambda ans, a, b: lambda g: ndarray.dot(a.T, g), 1),
        # non-linear
        ('tanh', lambda ans, x: lambda g: g * (1 - ans ** 2), 0),
        ('exp', lambda ans, x: lambda g: g * ans, 0),
        ('log', lambda ans, x: lambda g: g / x, 0),
        # reduce
        ('sum', _sum_grad, 0),
        # + - * /
        ('multiply',
         lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y), 0),
        ('multiply',
         lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), 1),
        ('add', lambda ans, x, y: _unbroadcast(ans, x, identity), 0),
        ('add', lambda ans, x, y: _unbroadcast(ans, y, identity), 1),
        ('subtract', lambda ans, x, y: _unbroadcast(ans, x, identity), 0),
        ('subtract',
         lambda ans, x, y: _unbroadcast(ans, y, operator.neg), 1),
        ('divide',
         lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y), 0),
        ('divide',
         lambda ans, x, y: _unbroadcast(
             ans, y, lambda g: -g * x / (y * y)), 1),
        ('true_divide',
         lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y), 0),
        ('true_divide',
         lambda ans, x, y: _unbroadcast(
             ans, y, lambda g: -g * x / (y * y)), 1),
        ('maximum', _maximum_grad_gen0, 0),
        ('maximum', _maximum_grad_gen1, 1),
        # TODO: minjie -- zero is a placeholder; the true gradient of a max
        # reduction routes g to the maximal element(s).
        ('max', None, 0),
        # negate
        ('negative', lambda ans, x: operator.neg, 0),
        ('transpose', lambda ans, x: mxnet.nd.transpose, 0),
        ('abs', lambda ans, x: lambda g: mxnet.nd.sign(x) * g, 0),
        ('sign', None, 0),
        ('round', None, 0),
        ('ceil', None, 0),
        ('floor', None, 0),
        ('sqrt', lambda ans, x: lambda g: g * 0.5 / mxnet.nd.sqrt(x), 0),
        ('sin', lambda ans, x: lambda g: g * mxnet.nd.cos(x), 0),
        ('cos', lambda ans, x: lambda g: -g * mxnet.nd.sin(x), 0),
        ('power',
         lambda ans, x, y: _unbroadcast(
             ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)), 0),
        ('power',
         lambda ans, x, y: _unbroadcast(
             ans, y, lambda g: g * mxnet.nd.log(x) * ans), 1),
        ('reshape',
         lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape), 0),
        ('expand_dims',
         lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape), 0),
    )

    for name, grad_gen, argnum in registrations:
        prim = prims(name)
        if grad_gen is None:
            prim.def_grad_zero()
        elif argnum:
            prim.def_grad(grad_gen, argnum=argnum)
        else:
            prim.def_grad(grad_gen)
Example #2
0
def def_grads(prims):
    """Register gradient generators for the primitives exposed by ``prims``.

    ``prims(name)`` returns the primitive called ``name``. Each generator
    passed to ``def_grad`` receives the forward result ``ans`` followed by
    the primitive's inputs. It returns a function that maps the upstream
    gradient ``g`` to the gradient of one operand: the first operand when
    ``argnum`` is omitted, otherwise operand ``argnum``.
    ``def_grad_zero()`` declares a zero gradient.
    """
    identity = lambda x: x
    # dot
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T))
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(a.T, g),
                          argnum=1)
    # non-linear
    # NOTE: no gradient is registered for 'tanh' here. If needed:
    # d/dx tanh(x) = 1 - tanh(x)**2, i.e.
    # `lambda ans, x: lambda g: g * (1 - ans ** 2)`.
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(_sum_grad)
    # + - * /
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity),
                          argnum=1)
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('maximum').def_grad(_maximum_grad_gen0)
    prims('maximum').def_grad(_maximum_grad_gen1, argnum=1)
    # TODO: minjie -- zero is a placeholder; the true gradient of a max
    # reduction routes g to the maximal element(s).
    prims('max').def_grad_zero()
    # negate
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g)
    prims('sign').def_grad_zero()
    prims('round').def_grad_zero()
    prims('ceil').def_grad_zero()
    prims('floor').def_grad_zero()
    # d/dx sqrt(x) = 0.5 / sqrt(x). The forward result `ans` already is
    # sqrt(x), so reuse it instead of recomputing the square root.
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / ans)
    prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x))
    prims('power').def_grad(lambda ans, x, y: _unbroadcast(
        ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)))
    prims('power').def_grad(lambda ans, x, y: _unbroadcast(
        ans, y, lambda g: g * mxnet.nd.log(x) * ans),
                            argnum=1)
    prims('reshape').def_grad(
        lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape))
    prims('expand_dims').def_grad(
        lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
Example #3
0
    def forward(self,
                encoder_output: nd.NDArray,
                label=None,
                label_lengths=None):
        """Run the attention decoder over ``encoder_output``.

        Args:
            encoder_output: encoder features. They are transposed with axes
                ``(0, 2, 3, 1)`` and flattened to ``(N, -1, C)``, where
                ``C`` is the last axis after the transpose. Assumes a 4-D
                channel-first input — TODO confirm with the encoder.
            label: token ids used for teacher forcing.
            label_lengths: per-sample label lengths; decoding runs for
                ``max(label_lengths) - 1`` steps.

        If either ``label`` or ``label_lengths`` is ``None``, the decoder
        runs ``self.max_len`` steps free-running. It starts from the
        embedding of token id 0 and feeds back the argmax of each step's
        prediction.

        Returns:
            ``(predictions, alphas)``: the per-step ``preds`` and attention
            ``alpha`` values, stacked along axis 1.
        """
        no_label = label is None or label_lengths is None

        encoder_output = nd.transpose(encoder_output, (0, 2, 3, 1))
        encoder_output = encoder_output.reshape(
            (encoder_output.shape[0], -1, encoder_output.shape[3]))
        batch_max_len = self.max_len if no_label else int(
            label_lengths.max().asscalar()) - 1

        # Initialize hidden states from the mean encoder feature
        encoder_output_mean = encoder_output.mean(axis=1)
        h = self.init_h(encoder_output_mean)
        c = self.init_c(encoder_output_mean)

        # Per-step outputs, stacked after the loop
        predictions = []
        alphas = []

        if not no_label:
            label_embedded = self.embedding(label)
        else:
            bs = encoder_output.shape[0]
            x_t = self.embedding(
                nd.zeros(shape=(bs, ), ctx=encoder_output.context))
        for t in range(batch_max_len):
            if not no_label:
                # Teacher forcing: input is the embedded label at step t
                x_t = label_embedded[:, t]
            if self._use_current_state:
                # Update the LSTM first, then attend with the new state
                _, [h, c] = self.lstm_cell(x_t, [h, c])
                if self._use_adaptive_attention:
                    atten_weights, alpha = self.attention(
                        encoder_output, h, x_t, c)
                else:
                    atten_weights, alpha = self.attention(encoder_output, h)
                atten_weights = self.f_beta(h).sigmoid() * atten_weights
                inputs = nd.concat(atten_weights, h, dim=1)
                preds = self.out(self.dropout(inputs))
            else:
                # Attend with the previous state, then update the LSTM
                atten_weights, alpha = self.attention(encoder_output, h)
                atten_weights = nd.sigmoid(self.f_beta(h)) * atten_weights
                inputs = nd.concat(x_t, atten_weights, dim=1)
                _, [h, c] = self.lstm_cell(inputs, [h, c])
                preds = self.out(self.dropout(h))
            if no_label:
                # Free-running: feed back the predicted token. With labels,
                # x_t is overwritten from label_embedded at the next step,
                # so embedding the prediction there is wasted work.
                x_t = self.embedding(preds.argmax(axis=1))
            predictions.append(preds)
            alphas.append(alpha)
        # Same result as concatenating expand_dims(axis=1) of every step
        predictions = nd.stack(*predictions, axis=1)
        alphas = nd.stack(*alphas, axis=1)

        return predictions, alphas
Example #4
0
 def stat_helper(name, array):
     """Executor monitor callback that reports time between invocations.

     Wraps the raw handle as a read-only NDArray and waits until it can be
     read. If more than 0.01 ms passed since ``stat_helper.start_time``, it
     prints the name, the shape and the elapsed milliseconds. It then
     restarts the timer.
     """
     import ctypes
     from mxnet.ndarray import NDArray
     from mxnet.base import NDArrayHandle
     wrapped = NDArray(ctypes.cast(array, NDArrayHandle), writable=False)
     wrapped.wait_to_read()
     elapsed_ms = 1000 * float(time.time() - stat_helper.start_time)
     if elapsed_ms > .01:
         print(name, wrapped.shape, '%.1fms' % elapsed_ms)
     stat_helper.start_time = time.time()
Example #5
0
 def stat_helper(name, array):
     """Executor monitor callback that prints summary statistics of an output.

     Wraps the raw handle as a read-only NDArray, copies it to NumPy and
     prints the name, shape, mean, standard deviation and the milliseconds
     elapsed since ``stat_helper.start_time``. The timer is not reset here.
     The ``(name, shape)`` pair is appended to ``stat_helper.internal_shapes``.
     """
     import ctypes
     from mxnet.ndarray import NDArray
     from mxnet.base import NDArrayHandle
     handle = ctypes.cast(array, NDArrayHandle)
     array = NDArray(handle, writable=False)
     array.wait_to_read()
     array = array.asnumpy()
     mean, std = np.average(array), np.std(array)
     # Timestamp taken after the host copy and the statistics, exactly
     # where the previous inline expression sampled it.
     elapsed_ms = float(time.time() - stat_helper.start_time) * 1000
     print(name, array.shape, mean, std, ('%.1fms' % elapsed_ms))
     stat_helper.internal_shapes.append((name, array.shape))
    def _partition_tensor(self, size):
        """Split ``self._tensor`` along axis 0 into zero-copy row slices.

        Each slice is taken from a freshly created avatar handle of the
        tensor, so no data is copied. The partition count is
        ``ceil(self._tensor.size / size)``; it is capped at the number of
        rows, with a warning. Rows are spread as evenly as possible, and the
        last ``rows % count`` partitions receive one extra row each.

        Note: ndarray works for up to ~4 billion parameters.

        Slicing a flattened view instead is buggy with horovod -- it causes
        bad performance:
            tmp = self._tensor.reshape(-1, 1)
            avatar = tmp[start:end]
        """
        n_parts = (self._tensor.size - 1) // size + 1
        total_rows = self._tensor.shape[0]
        if n_parts > total_rows:
            self._logger.warning(
                "The number of tensor rows (with shape {}) is smaller than partition number {}."
                .format(self._tensor.shape, n_parts))
            n_parts = total_rows
        base, remainder = divmod(total_rows, n_parts)
        first_with_extra = n_parts - remainder

        partitions = []
        start = 0
        for idx in range(n_parts):
            stop = start + base + (1 if idx >= first_with_extra else 0)
            handle = NDArrayHandle()
            check_call(
                BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                    self._tensor.handle, byref(handle)))
            partitions.append(NDArray(handle)[start:stop])
            start = stop
        return partitions
    def _prepare(self):
        """Post this task's five OPs to the MXNet engine, in order.

        start barrier OP: barrier on the start of a parent ByteTask, used to
            maintain the original dependency.
        start OP: notifies Core about task readiness; also used to delay the
            start of a child ByteTask.
        comm OP: the OP that does the real communication, e.g. push, pull,
            allreduce.
        end OP: runs after a child ByteTask is finished and notifies Core
            about the task completion.
        end barrier OP: runs after the parent ByteTask is finished, used to
            maintain the original dependency.
        """
        is_root = self.parent is None
        if is_root:
            # Root task: create an avatar handle of our own tensor and keep
            # the wrapping NDArray alive on self._avatar.
            real = self._tensor.handle
            fresh_handle = NDArrayHandle()
            check_call(
                BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                    real, byref(fresh_handle)))
            self._avatar = NDArray(fresh_handle)
            avatar = self._avatar.handle
        else:
            # Child task: depend on the parent's tensor, operate on ours.
            real = self.parent._tensor.handle
            avatar = self._tensor.handle

        self._post_start_barrier(avatar, real)
        self._post_start_op(avatar)

        # Post the real communication op
        self._post_communication(self._avatar if is_root else self._tensor)

        self._post_end_op(avatar)
        self._post_end_barrier(avatar, real)
Example #8
0
    def _partition_single_tensor(self, tensor, size):
        """Split one tensor along axis 0 into zero-copy row slices.

        Arguments:
            tensor: the NDArray to split; every slice is taken from a fresh
                avatar handle of ``tensor``.
            size: An integer. After partitioning, each tensor partition size
                must be equal or smaller than `size`.

        Returns:
            A list of partitioned tensors. The partition count is capped at
            the number of rows (with a warning). The last ``rows % count``
            partitions hold one more row than the others.
        """
        count = (tensor.size - 1) // size + 1
        rows = tensor.shape[0]
        if count > rows:
            self._logger.warning(
                "The number of tensor rows (with shape {}) is smaller than partition number {}."
                .format(tensor.shape, count))
            count = rows
        base = rows // count
        extra = rows % count
        lengths = [base] * (count - extra) + [base + 1] * extra

        partitions = []
        offset = 0
        for length in lengths:
            handle = NDArrayHandle()
            check_call(
                BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                    tensor.handle, byref(handle)))
            partitions.append(NDArray(handle)[offset:offset + length])
            offset += length
        return partitions
def view_classify(img: NDArray, ps: NDArray, version="MNIST"):
    """Show an image next to a horizontal bar chart of its class probabilities.

    Args:
        img: image to display; squeezed before ``imshow``.
        ps: class probabilities; squeezed to one value per class.
        version: when ``"MNIST"``, the y-axis ticks are labelled with the
            class indices.
    """
    ps = ps.asnumpy().squeeze()
    # Derive the class count from the probabilities rather than assuming
    # 10, so models with a different number of classes can be plotted too.
    classes = np.arange(ps.size)

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    ax1.imshow(img.asnumpy().squeeze())
    ax1.axis('off')
    ax2.barh(classes, ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(classes)
    if version == "MNIST":
        ax2.set_yticklabels(classes)
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()
def view_recon(img: NDArray, recon):
    """Display an image and its reconstruction side by side.

    Args:
        img: the original image as an NDArray; squeezed before display.
        recon: the reconstruction; must provide ``asnumpy()`` like ``img``.
    """
    _, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
    axes[0].imshow(img.asnumpy().squeeze())
    axes[1].imshow(recon.asnumpy().squeeze())
    for ax in axes:
        ax.axis('off')
        try:
            # 'box-forced' was removed from matplotlib; 'box' now works on
            # shared axes.
            ax.set_adjustable('box')
        except ValueError:
            # Old matplotlib rejects 'box' on shared axes and needs the
            # legacy value.
            ax.set_adjustable('box-forced')
Example #11
0
 def collect(self, name, arr):
     """Callback function for collecting layer output NDArrays.

     The output of every layer accepted by ``self.include_layer`` (all
     layers when it is ``None``) is copied to CPU and appended to
     ``self.nd_dict[name]``.
     """
     name = py_str(name)
     # Wrap the raw handle before filtering. The included path already
     # drops the wrapper right after ``copyto``, i.e. it relies on the
     # wrapper to release the handle; returning before wrapping would never
     # release the handles of skipped layers.
     # NOTE(review): assumes the executor hands handle ownership to this
     # callback — confirm.
     handle = ctypes.cast(arr, NDArrayHandle)
     output = NDArray(handle, writable=False)
     if self.include_layer is not None and not self.include_layer(name):
         return
     arr = output.copyto(cpu())
     if self.logger is not None:
         self.logger.info("Collecting layer %s output of shape %s",
                          name, arr.shape)
     self.nd_dict.setdefault(name, []).append(arr)
Example #12
0
 def _batchify(data: nd.NDArray, batch_size):
     """
     Arrange a 1-D sequence into ``batch_size`` columns.

     Trailing elements that do not fill a complete row are dropped.

     :param data: vector
     :param batch_size: number of columns
     :return: tensor of shape (len(data) // batch_size, batch_size)
     """
     rows = len(data) // batch_size
     usable = rows * batch_size
     columns = data[0:usable].reshape(batch_size, -1)
     return columns.transpose()
Example #13
0
 def collect(self, name, arr):
     """Callback function for collecting min and max values from an NDArray.

     For every layer accepted by ``self.include_layer`` (all layers when it
     is ``None``), widens the running ``(min, max)`` pair stored in
     ``self.min_max_dict[name]`` with the extrema of ``arr``.
     """
     name = py_str(name)
     # Wrap the raw handle before filtering so that skipped layers still
     # hand their handle to an NDArray, which releases it when collected.
     # NOTE(review): assumes the executor hands handle ownership to this
     # callback — confirm.
     handle = ctypes.cast(arr, NDArrayHandle)
     arr = NDArray(handle, writable=False)
     if self.include_layer is not None and not self.include_layer(name):
         return
     min_range = ndarray.min(arr).asscalar()
     max_range = ndarray.max(arr).asscalar()
     if name in self.min_max_dict:
         cur_min, cur_max = self.min_max_dict[name]
         self.min_max_dict[name] = (min(cur_min, min_range),
                                    max(cur_max, max_range))
     else:
         self.min_max_dict[name] = (min_range, max_range)
     if self.logger is not None:
         self.logger.info("Collecting layer %s min_range=%f, max_range=%f",
                          name, min_range, max_range)
Example #14
0
    def train(self, d: HybridSequential, x: NDArray, y: NDArray, *,
              image_shape=(3, 96, 96)) -> float:
        """Run one update step of ``self._network`` against ``d``.

        Args:
            d: discriminator, applied to ``concat(x, y_hat, dim=1)``.
            x: input fed to ``self._network``.
            y: target images.
            image_shape: per-sample shape that ``y`` and the network output
                are reshaped to (with a leading batch axis ``-1``) before
                they are passed to ``self.lossfun``.

        Returns:
            The loss as a Python float. ``asscalar()`` requires the loss to
            contain exactly one element.
        """
        target_shape = (-1,) + tuple(image_shape)
        with autograd.record():
            y_hat = self._network(x)
            # NOTE(review): the leading 1 is presumably the "real" label
            # for the discriminator output — confirm against self.lossfun.
            loss = self.lossfun(1, d(concat(x, y_hat, dim=1)),
                                y.reshape(target_shape),
                                y_hat.reshape(target_shape))

        loss.backward()
        self.trainer.step(1)

        return float(loss.asscalar())
Example #15
0
 def stat_helper(name, array):
     """Monitor callback: queue ``stat(array)`` for matching, active layers."""
     # Wrap the raw handle before any early return.
     # NOTE(review): presumably the NDArray wrapper releases the handle when
     # collected, so skipping the wrap would leak it — keep this ordering.
     handle = ctypes.cast(array, NDArrayHandle)
     wrapped = NDArray(handle, writable=False)
     if not self.activated:
         return
     layer = py_str(name)
     if not self.re_prog.match(layer):
         return
     self.queue.append((self.step, layer, stat(wrapped)))
Example #16
0
def sq_loss(ey: nd.NDArray, y: nd.NDArray):
    """Mean squared error between ``ey`` (reshaped to ``y.shape``) and ``y``."""
    squared = (ey.reshape(shape=y.shape) - y) ** 2
    return squared.sum() / squared.size
Example #17
0
    def _prepare(self):
        """Post start barrier OP, start OP, comm OP, end OP and end barrier OP to MXNet engine. The function of each
        kind of OP is explained below.
        start barrier OP: barrier the start of a parent ByteTask, used to maintain original dependency.
        start OP: It notifies Core about task readiness. It is also used to delay the start of a child ByteTask.
        comm OP: the OP that does real communication, e.g., push, pull, allreduce.
        end OP: an OP that runs after a child ByteTask is finished. It notifies Core about the task completion.
        end barrier OP: an OP that runs after the parent ByteTask is finished, used to maintain original dependency.

        Below are several key data structures.

        self._tensor: a list of NDArrays of the same key of all devices. If push_pull, self._tensor includes push list
        and pull list of NDArrays of all devices.
        real: the original handle list of self._tensor, used for keep dependency.
        avatar: a new handle list of self._tensor.
        """

        def handles_of(tensors):
            """Return the handles of one NDArray or a list/tuple of NDArrays as a list."""
            if isinstance(tensors, (tuple, list)):
                return [t.handle for t in tensors]
            return [tensors.handle]

        if self.parent is None:
            if self.op == "push_pull":
                push_real = handles_of(self._push_tensor)
                pull_real = handles_of(self._pull_tensor)
                assert len(push_real) == len(pull_real)
                real = push_real + pull_real
            else:
                real = handles_of(self._tensor)
            # One fresh avatar handle per real handle, in the same order.
            avatar = []
            for h in real:
                avatar_h = NDArrayHandle()
                check_call(
                    BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                        h, byref(avatar_h)))
                avatar.append(avatar_h)
            if self.op == "push_pull":
                # push avatar and pull avatar NDArrays: the first half of
                # the handles belongs to push, the second half to pull.
                half = len(avatar) // 2
                self._avatar = [[NDArray(a) for a in avatar[:half]],
                                [NDArray(a) for a in avatar[half:]]]
                avatar = ([a.handle for a in self._avatar[0]] +
                          [a.handle for a in self._avatar[1]])
            else:
                self._avatar = [NDArray(a) for a in avatar]
                avatar = [a.handle for a in self._avatar]
        else:
            # Child task: depend on the parent's handles, operate on ours.
            if self.op == "push_pull":
                real = (handles_of(self.parent._push_tensor) +
                        handles_of(self.parent._pull_tensor))
                avatar = (handles_of(self._push_tensor) +
                          handles_of(self._pull_tensor))
            else:
                real = handles_of(self.parent._tensor)
                avatar = handles_of(self._tensor)

        self._post_start_barrier(avatar, real)
        self._post_start_op(avatar)
        self._post_push_pull_barrier(avatar)

        # post real op
        if self.parent is None:
            self._post_communication(self._avatar)
        else:
            self._post_communication(self._tensor)

        self._post_end_op(avatar)

        self._post_end_barrier(avatar, real)
Example #18
0
 def _repackage_hidden(h: nd.NDArray):
     """Return ``h`` detached from the autograd graph.

     The result holds the same values, but gradients do not propagate
     through it into the computation that produced ``h``.
     """
     detached = h.detach()
     return detached
Example #19
0
def argmax_batch(vecs: nd.NDArray):
    """Return the index of the maximum along axis 1 for every row of ``vecs``."""
    return nd.argmax(vecs, axis=1)
Example #20
0
def argmax(vec: nd.NDArray):
    """Return the argmax of ``vec`` along axis 1 as an NDArray.

    NOTE(review): the result is an NDArray, not a Python int; callers that
    need a scalar must convert it (e.g. ``int(x.asscalar())``).
    """
    return nd.argmax(vec, axis=1)
def def_grads(prims):
    """Register gradient generators for the primitives exposed by ``prims``.

    ``prims(name)`` returns the primitive called ``name``. Each generator
    passed to ``def_grad`` receives the forward result ``ans`` followed by
    the primitive's inputs. It returns a function that maps the upstream
    gradient ``g`` to the gradient of one operand: the first operand when
    ``argnum`` is omitted, otherwise operand ``argnum``.
    ``def_grad_zero()`` declares a zero gradient.
    """
    identity = lambda x: x
    # dot: use the transpose flags instead of materializing a.T / b.T
    prims('dot').def_grad(
        lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True))
    prims('dot').def_grad(
        lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True),
        argnum=1)
    # non-linear
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans**2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(_reduce_sum_grad_gen)
    prims('max').def_grad(_reduce_select_grad_gen)
    prims('min').def_grad(_reduce_select_grad_gen)
    # + - * /
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity),
                          argnum=1)
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('maximum').def_grad(_selection_grad_gen0)
    prims('maximum').def_grad(_selection_grad_gen1, argnum=1)
    prims('minimum').def_grad(_selection_grad_gen0)
    prims('minimum').def_grad(_selection_grad_gen1, argnum=1)
    # negate
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mx.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mx.nd.sign(x) * g)
    prims('sign').def_grad_zero()
    prims('round').def_grad_zero()
    prims('ceil').def_grad_zero()
    prims('floor').def_grad_zero()
    # d/dx sqrt(x) = 0.5 / sqrt(x). The forward result `ans` already is
    # sqrt(x), so reuse it instead of recomputing the square root.
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / ans)
    prims('sin').def_grad(lambda ans, x: lambda g: g * mx.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mx.nd.sin(x))
    prims('power').def_grad(lambda ans, x, y: _unbroadcast(
        ans, x, lambda g: g * y * mx.nd.power(x, y - 1)))
    prims('power').def_grad(lambda ans, x, y: _unbroadcast(
        ans, y, lambda g: g * mx.nd.log(x) * ans),
                            argnum=1)
    prims('reshape').def_grad(
        lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape))
    prims('expand_dims').def_grad(
        lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
    prims('softmax_output').def_grad(_softmax_output_grad)