def __get_argsort_indices(self, a, axis):
    """
    Build per-dimension index expressions that undo sorting ``a`` along ``axis``.

    The gradient w.r.t. the input of ``sort(a, axis)`` is obtained by routing
    the gradient w.r.t. the sorted output back through the inverse
    permutation. Taking ``argsort`` of an ``argsort`` gives that inverse
    permutation along ``axis``.

    ``axis`` must expose ``.data`` (it is read directly), so it cannot be
    ``None`` here. Presumably callers handle ``axis=None`` before calling;
    TODO confirm.

    Returns
    -------
    list
        ``a.ndim`` index expressions, one per dimension. For each dimension
        ``switch`` selects the inverse permutation when that dimension equals
        the (non-negative) ``axis`` and ``self.__get_expanded_dim(a, axis, dim)``
        otherwise.
    """
    # Forward sort permutation, then its inverse (argsort of the argsort).
    sort_perm = argsort(a, axis, kind=self.kind, order=self.order)
    inverse_perm = argsort(sort_perm, axis, kind=self.kind, order=self.order)

    # Map a negative axis to its non-negative equivalent (axis + ndim).
    axis_val = axis.data
    normalized_axis = switch(ge(axis_val, 0), axis_val, a.ndim + axis_val)

    return [
        switch(
            eq(dim, normalized_axis),
            inverse_perm,
            self.__get_expanded_dim(a, axis, dim),
        )
        for dim in range(a.ndim)
    ]
def grad(self, inp, cost_grad):
    """
    Gradient of ``fill_diagonal_offset(a, val, offset)`` w.r.t. its inputs.

    Notes
    -----
    The gradient is currently implemented for matrices only.

    Parameters
    ----------
    inp : sequence
        The op inputs ``(a, val, offset)``.
    cost_grad : sequence
        Gradients w.r.t. the op outputs; only ``cost_grad[0]`` is used.

    Returns
    -------
    list
        Exactly one term per input, in the order ``(a, val, offset)``:

        * w.r.t. ``a``: the output gradient with the offset diagonal set to 0,
          because those entries were overwritten by ``val``;
        * w.r.t. ``val``: the sum of the output gradient along the offset
          diagonal;
        * w.r.t. ``offset``: undefined, built with ``grad_undefined``.

        If ``a`` has a complex dtype, every term is ``None``.
    """
    a, val, offset = inp
    grad = cost_grad[0]
    height, width = grad.shape

    if a.dtype.startswith("complex"):
        # The op has three inputs (a, val, offset), so return one term per
        # input. The previous ``[None, None]`` did not match the input count.
        # NOTE(review): the gradient machinery may reject ``None`` terms
        # outright; consider grad_undefined / grad_not_implemented. Confirm.
        return [None, None, None]

    # only valid for matrices
    wr_a = fill_diagonal_offset(grad, 0, offset)

    # Locate the offset diagonal inside the row-major flattened gradient:
    #   offset >= 0 -> starts at row 0, column offset     -> flat index offset
    #   offset <  0 -> starts at row |offset|, column 0   -> flat index |offset| * width
    # The boolean flags select the matching term symbolically.
    offset_abs = abs_(offset)
    pos_offset_flag = ge(offset, 0)
    neg_offset_flag = lt(offset, 0)
    min_wh = minimum(width, height)

    start = offset * pos_offset_flag + offset_abs * width * neg_offset_flag
    # Diagonal length: min(height, width - offset) for offset >= 0 and
    # min(width, height - |offset|) for offset < 0.
    num_of_step = minimum(
        min_wh, width * pos_offset_flag + height * neg_offset_flag - offset_abs
    )

    # Consecutive diagonal entries are one full row (a.shape[1] elements)
    # plus one column apart in the flattened layout.
    step = a.shape[1] + 1
    end = start + step * num_of_step

    # input of slice should be integer
    start = aet.cast(start, "int32")
    step = aet.cast(step, "int32")
    end = aet.cast(end, "int32")

    # d(cost)/d(val): val was written to every diagonal entry, so sum them.
    wr_val = grad.flatten()[start:end:step].sum()

    wr_offset = grad_undefined(
        self,
        2,
        offset,
        "offset is not defined for non-integer offset so"
        " fill_diagonal_offset(a,val,offset+eps) is undefined",
    )

    return [wr_a, wr_val, wr_offset]