Esempio n. 1
0
    def __get_argsort_indices(self, a, axis):
        """
        Build index tensors that undo a sort of ``a`` along ``axis``.

        The result is meant for advanced indexing. Along the sorted dimension
        it holds the inverse permutation of ``argsort(a, axis)``. Along every
        other dimension it holds whatever ``self.__get_expanded_dim`` returns
        for that dimension.

        Parameters
        ----------
        a : tensor variable
            The tensor whose sort is being inverted.
        axis : constant variable
            Its ``.data`` is read directly, so it must carry a concrete
            value. Negative values are normalised against ``a.ndim``.
            NOTE(review): there is no ``axis is None`` branch here; that case
            is presumably flattened by the caller beforehand. Verify.

        Returns
        -------
        list
            ``a.ndim`` symbolic entries, one per dimension of ``a``.
        """
        # The gradient w.r.t. the input is recovered from the gradient
        # w.r.t. sort(input, axis) by routing it back through the inverse
        # of the sorting permutation.
        sort_order = argsort(a, axis, kind=self.kind, order=self.order)
        # Arg-sorting a permutation yields its inverse permutation.
        inverse_order = argsort(sort_order, axis, kind=self.kind, order=self.order)
        # Map a negative axis onto its non-negative equivalent.
        axis_value = axis.data
        positive_axis = switch(ge(axis_value, 0), axis_value, a.ndim + axis_value)
        return [
            switch(
                eq(dim, positive_axis),
                inverse_order,
                self.__get_expanded_dim(a, axis, dim),
            )
            for dim in range(a.ndim)
        ]
Esempio n. 2
0
    def grad(self, inp, cost_grad):
        """
        Gradient of ``fill_diagonal_offset(a, val, offset)``.

        Parameters
        ----------
        inp : sequence of three variables
            ``(a, val, offset)``, the inputs of the op.
        cost_grad : sequence
            ``cost_grad[0]`` is the gradient of the cost w.r.t. the op's
            output (same layout as ``a``).

        Returns
        -------
        list of three items
            Gradients w.r.t. ``a``, ``val`` and ``offset``, in that order:

            - ``a``: the output gradient with the selected diagonal zeroed.
            - ``val``: the sum of the output gradient along that diagonal.
            - ``offset``: undefined.

            All three entries are ``None`` when ``a`` has a complex dtype.

        Notes
        -----
        The gradient is currently implemented for matrices only.
        """
        a, val, offset = inp
        grad = cost_grad[0]

        # Complex inputs are not supported. Return one entry per input
        # (a, val, offset) so the length matches the three-item result of
        # the regular path below.
        if a.dtype.startswith("complex"):
            return [None, None, None]

        # only valid for matrices
        height, width = grad.shape

        # Entries overwritten by the op do not depend on ``a``, so their
        # gradient is zero.
        wr_a = fill_diagonal_offset(grad, 0, offset)

        offset_abs = abs_(offset)
        pos_offset_flag = ge(offset, 0)
        neg_offset_flag = lt(offset, 0)
        min_wh = minimum(width, height)

        # Flat (row-major) index of the diagonal's first element:
        # (0, offset) for offset >= 0, (|offset|, 0) otherwise.
        start = offset * pos_offset_flag + offset_abs * width * neg_offset_flag
        # Diagonal length: min(height, width - offset) for offset >= 0,
        # min(width, height - |offset|) otherwise.
        num_of_step = minimum(
            min_wh,
            width * pos_offset_flag + height * neg_offset_flag - offset_abs)

        # Moving one row down and one column right advances width + 1 in
        # the flattened array.
        step = a.shape[1] + 1
        end = start + step * num_of_step

        # input of slice should be integer
        start = aet.cast(start, "int32")
        step = aet.cast(step, "int32")
        end = aet.cast(end, "int32")

        # ``val`` is written to every diagonal element, so its gradient is
        # the sum of the output gradient along that diagonal.
        wr_val = grad.flatten()[start:end:step].sum()

        wr_offset = grad_undefined(
            self,
            2,
            offset,
            "offset is not defined for non-integer offset so"
            " fill_diagonal_offset(a,val,offset+eps) is undefined",
        )

        return [wr_a, wr_val, wr_offset]