def SQSOneStep(reconNet, x, z, ref, prj, weight, normImg, projectorNorm, args, verbose=0):
    """Run one ordered-subsets SQS iteration with an NLM prior and Nesterov momentum.

    Starting from ``z``, each subset applies

        x_new <- x_new - (nSubsets * bp + betaRecon * 4 * (x_new - NLM(x_new)))
                         / (normImg + 8 * betaRecon)

    where ``bp`` is the weighted, normalized backprojection of the projection
    residual over that subset's views. After all subsets, the next
    extrapolated point is ``x_new + nesterov * (x_new - x)``.

    NOTE(review): this SOURCE defines ``SQSOneStep`` twice with different
    signatures; if both live in the same module, the later definition shadows
    this one. Confirm which is intended.

    Args:
        reconNet: projector object providing ``rotview`` (view count),
            ``angles`` and the ``cDDFanProjection3d`` /
            ``cDDFanBackprojection3d`` operators.
        x: image from the previous iteration (only used for momentum).
        z: extrapolated image that the subset updates start from.
        ref: reference image passed to ``HYPR_NLM.NLM`` (presumably guides
            the NLM weights; TODO confirm).
        prj: measured projections; axis 2 indexes views.
        weight: statistical weights laid out like ``prj``.
        normImg: per-pixel data term of the SQS denominator (presumably the
            data-term curvature; confirm against caller).
        projectorNorm: scale dividing every projection and backprojection.
        args: options with ``nSubsets``, ``betaRecon``, ``nesterov`` and the
            NLM parameters ``searchSize``, ``kernelSize``, ``kernelStd``,
            ``sigma``.
        verbose: if truthy, print each subset index as it is processed.

    Returns:
        ``(x, z, dataLoss, nlmLoss)``: the updated image, the next
        extrapolated image, ``0.5 * sum(weight * (fp - prj)**2)`` in
        normalized units, and ``sum(x**2 - 2*x*NLM(x) + NLM(x**2))``.

    Raises:
        ValueError: if ``args.nSubsets`` is not positive or does not divide
            ``reconNet.rotview``.
    """
    # Without this guard, 0 fails with ZeroDivisionError and negative values
    # silently skip every update.
    if args.nSubsets < 1:
        raise ValueError('args.nSubsets must be a positive integer')
    if reconNet.rotview % args.nSubsets != 0:
        raise ValueError('reconNet.rotview cannot be divided by args.nSubsets')

    def applyNlm(img):
        # NLM filter of img with ref and the configured NLM parameters.
        return HYPR_NLM.NLM(img, ref, args.searchSize, args.kernelSize, args.kernelStd, args.sigma)

    # Reorder views into the helper's bit-reversed subset ordering.
    inds = helper.OrderedSubsetsBitReverse(reconNet.rotview, args.nSubsets)
    angles = np.array([reconNet.angles[i] for i in inds], np.float32)
    # Normalize once, instead of per subset and again for the loss.
    scaledPrj = prj[:, :, inds, :] / projectorNorm
    weight = weight[:, :, inds, :]
    nAnglesPerSubset = reconNet.rotview // args.nSubsets

    # Loop-invariant SQS denominator; 8 * betaRecon pairs with the
    # 4 * (x - NLM(x)) prior term below.
    denom = normImg + args.betaRecon * 8

    x_new = np.copy(z)
    for i in range(args.nSubsets):
        if verbose:
            print('set%d' % i, end=',', flush=True)
        views = slice(i * nAnglesPerSubset, (i + 1) * nAnglesPerSubset)
        curAngles = angles[views]
        curWeight = weight[:, :, views, :]

        # Projection term: weighted backprojected residual for this subset.
        fp = reconNet.cDDFanProjection3d(x_new, curAngles) / projectorNorm
        dprj = fp - scaledPrj[:, :, views, :]
        bp = reconNet.cDDFanBackprojection3d(curWeight * dprj, curAngles) / projectorNorm

        # Prior term.
        sqsNlm = 4 * (x_new - applyNlm(x_new))
        # bp is scaled by nSubsets (ordered-subsets approximation of the
        # full-data gradient).
        x_new = x_new - (args.nSubsets * bp + args.betaRecon * sqsNlm) / denom

    # Nesterov extrapolation from the previous iterate.
    z = x_new + args.nesterov * (x_new - x)
    # x_new was rebuilt by arithmetic in the loop and never aliases an input,
    # so no defensive copy is needed.
    x = x_new

    # Data-fidelity loss over all views.
    fp = reconNet.cDDFanProjection3d(x, angles) / projectorNorm
    dataLoss = 0.5 * np.sum(weight * (fp - scaledPrj)**2)

    # NLM loss.
    nlm = applyNlm(x)
    nlm2 = applyNlm(x**2)
    nlmLoss = np.sum(x**2 - 2 * x * nlm + nlm2)

    return x, z, dataLoss, nlmLoss
def SQSOneStep(reconNet, x, x_nesterov, z, masks, prj, weight, normImg, projectorNorm, args, verbose=0):
    """Run one ordered-subsets SQS iteration with a masked penalty toward ``z`` and Nesterov momentum.

    Starting from ``x_nesterov``, each subset applies

        x_new <- x_new - (nSubsets * bp + gamma * betaRecon * (x_new - z) * masks)
                         / (normImg + gamma * betaRecon * masks)

    where ``bp`` is the weighted, normalized backprojection of the projection
    residual over that subset's views. After all subsets, the next
    extrapolated point is ``x_new + nesterov * (x_new - x)``.

    NOTE(review): this SOURCE also contains an earlier ``SQSOneStep`` with an
    NLM prior and a different signature; if both live in one module, this
    definition shadows it. Confirm which is intended.

    Args:
        reconNet: projector object providing ``rotview`` (view count),
            ``angles`` and the ``cDDFanProjection3d`` /
            ``cDDFanBackprojection3d`` operators.
        x: image from the previous iteration (only used for momentum).
        x_nesterov: extrapolated image that the subset updates start from.
        z: target image the masked quadratic term pulls ``x`` toward.
        masks: per-pixel weights of that term (broadcast against the image).
        prj: measured projections; axis 2 indexes views.
        weight: statistical weights laid out like ``prj``.
        normImg: per-pixel data term of the SQS denominator (presumably the
            data-term curvature; confirm against caller).
        projectorNorm: scale dividing every projection and backprojection.
        args: options with ``nSubsets``, ``gamma``, ``betaRecon`` and
            ``nesterov``.
        verbose: if truthy, print each subset index as it is processed.

    Returns:
        ``(x, x_nesterov, dataLoss)``: the updated image, the next
        extrapolated image, and ``sum(weight * (fp - prj)**2)`` in
        normalized units.
        NOTE(review): unlike the NLM variant, there is no 0.5 factor here,
        although ``bp`` is the gradient of the 0.5-scaled data term. Confirm
        whether callers expect this.

    Raises:
        ValueError: if ``args.nSubsets`` is not positive or does not divide
            ``reconNet.rotview``.
    """
    # Without this guard, 0 fails with ZeroDivisionError and negative values
    # silently skip every update.
    if args.nSubsets < 1:
        raise ValueError('args.nSubsets must be a positive integer')
    if reconNet.rotview % args.nSubsets != 0:
        raise ValueError('reconNet.rotview cannot be divided by args.nSubsets')

    # Reorder views into the helper's bit-reversed subset ordering.
    inds = helper.OrderedSubsetsBitReverse(reconNet.rotview, args.nSubsets)
    angles = np.array([reconNet.angles[i] for i in inds], np.float32)
    # Normalize once, instead of per subset and again for the loss.
    scaledPrj = prj[:, :, inds, :] / projectorNorm
    weight = weight[:, :, inds, :]
    nAnglesPerSubset = reconNet.rotview // args.nSubsets

    # Loop-invariant SQS denominator (same expression as before, evaluated once).
    denom = normImg + args.gamma * args.betaRecon * masks

    x_new = np.copy(x_nesterov)
    for i in range(args.nSubsets):
        if verbose:
            print('set%d' % i, end=',', flush=True)
        views = slice(i * nAnglesPerSubset, (i + 1) * nAnglesPerSubset)
        curAngles = angles[views]
        curWeight = weight[:, :, views, :]

        # Projection term: weighted backprojected residual for this subset.
        fp = reconNet.cDDFanProjection3d(x_new, curAngles) / projectorNorm
        dprj = fp - scaledPrj[:, :, views, :]
        bp = reconNet.cDDFanBackprojection3d(curWeight * dprj, curAngles) / projectorNorm

        # bp is scaled by nSubsets (ordered-subsets approximation of the
        # full-data gradient); the second term pulls x_new toward z.
        x_new = x_new - (args.nSubsets * bp + args.gamma * args.betaRecon * (x_new - z) * masks) / denom

    # Nesterov extrapolation from the previous iterate.
    x_nesterov = x_new + args.nesterov * (x_new - x)
    # x_new was rebuilt by arithmetic in the loop and never aliases an input,
    # so no defensive copy is needed.
    x = x_new

    # Data-fidelity loss over all views.
    fp = reconNet.cDDFanProjection3d(x, angles) / projectorNorm
    dataLoss = np.sum(weight * (fp - scaledPrj)**2)

    return x, x_nesterov, dataLoss