def softmax(x, name="softmax", axis=0, frontend='keras'):
    """Build a softmax of ``x`` along ``axis``.

    The result is ``exp(x - m) / sum(exp(x - m))`` where ``m`` is the
    maximum of ``x`` along ``axis``. It is assembled from three
    ``hcl.compute`` stages: the per-slice maximum, the per-slice sum of
    shifted exponentials, and the final element-wise quotient.

    NOTE(review): ``max`` and ``sum`` are called with an ``axis=``
    keyword, so they are presumably module-level reducers defined outside
    this block rather than the Python builtins -- confirm.

    Args:
        x: Input tensor; must expose ``.shape`` and accept a tuple of
            indices.
        name: Name passed to the final ``hcl.compute`` stage.
        axis: Dimension to normalise over. Negative values count from the
            last dimension (``-1`` is the last one).
        frontend: Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        The result of the final ``hcl.compute`` call, built with shape
        ``x.shape``.

    Raises:
        IndexError: If ``axis`` is outside ``[-len(x.shape), len(x.shape))``.
    """
    shape = x.shape
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise IndexError(
            "softmax axis {} is out of range for a {}-dimensional "
            "input".format(axis, ndim))
    if axis < 0:
        # The index helpers below compare positions against `axis`, so it
        # has to be a non-negative position, not a from-the-end offset.
        axis += ndim

    # Shape of the per-slice statistics: the input shape without `axis`.
    reduced_shape = tuple(shape[:axis]) + tuple(shape[axis + 1:])

    def _with_axis(indices, reduce_var):
        # Rebuild a full index into `x` from an index into the reduced
        # shape by splicing the reduction variable in at position `axis`.
        indices = tuple(indices)
        return indices[:axis] + (reduce_var,) + indices[axis:]

    def _without_axis(indices):
        # Turn a full index into `x` into an index into the reduced shape
        # by dropping the entry at position `axis`.
        indices = tuple(indices)
        return indices[:axis] + indices[axis + 1:]

    k = hcl.reduce_axis(0, shape[axis])
    max_elem = hcl.compute(
        reduced_shape,
        lambda *y: max(x[_with_axis(y, k)], axis=[k]))

    # A fresh reduction axis is created for the second reduction stage.
    k = hcl.reduce_axis(0, shape[axis])
    expsum = hcl.compute(
        reduced_shape,
        lambda *y: sum(tvm.exp(x[_with_axis(y, k)] - max_elem[y]), axis=k))

    return hcl.compute(
        x.shape,
        lambda *y: tvm.exp(x[y] - max_elem[_without_axis(y)])
        / expsum[_without_axis(y)],
        name)
def softmax(out, x):
    """Write the row-wise softmax of a 2-D tensor ``x`` into ``out``.

    For every row ``i`` the row maximum is subtracted before
    exponentiating, and each ``exp(x[i, j] - max_i)`` is divided by the
    row's sum of those exponentials.

    NOTE(review): ``max`` and ``sum`` are called with an ``axis=``
    keyword, so they are presumably module-level reducers defined outside
    this block rather than the Python builtins -- confirm.

    Args:
        out: Tensor handed to ``hcl.update`` to receive the result.
        x: Input tensor; ``x.shape`` must have exactly two entries.

    Returns:
        Whatever ``hcl.update`` returns for ``out``.

    Raises:
        AssertionError: If ``x`` is not 2-dimensional.
    """
    assert len(x.shape) == 2, "only support 2-dim softmax"
    num_rows, num_cols = x.shape
    per_row_shape = (num_rows, )

    # Stage 1: maximum of each row.
    k = hcl.reduce_axis(0, num_cols)
    row_max = hcl.compute(
        per_row_shape,
        lambda i: max(x[i, k], axis=k))

    # Stage 2: sum of the shifted exponentials of each row, reduced over
    # a newly created reduction axis.
    k = hcl.reduce_axis(0, num_cols)
    row_expsum = hcl.compute(
        per_row_shape,
        lambda i: sum(tvm.exp(x[i, k] - row_max[i]), axis=k))

    # Stage 3: normalise every element by its row's sum, in place in `out`.
    return hcl.update(
        out,
        lambda i, j: tvm.exp(x[i, j] - row_max[i]) / row_expsum[i])