Пример #1
0
    def __init__(self,
                 x,
                 y,
                 filter_regs,
                 precond,
                 sample_weights,
                 net,
                 pixel_weighting,
                 compute_norm=False):
        """Store data, regularisation and preconditioning terms of the loss.

        Args:
            x: training samples (also stored as ``training_samples``).
            y: training labels (also stored as ``training_labels``); the size
                of its last two dimensions is kept in ``y_size``.
            filter_regs: regularisation weights, wrapped in a TensorList.
            precond: diagonal preconditioner terms, wrapped in a TensorList
                and stored as ``diag_M``.
            sample_weights: per-sample weights. A tensor whose first dimension
                is 1 is expanded to one weight per sample of ``x``.
            net: module stored as ``net``.
            pixel_weighting: pixel weights, stored as both ``w`` and
                ``pixel_weighting``.
            compute_norm: flag stored as ``compute_norm``.
        """
        super().__init__()

        self.training_samples = x
        self.training_labels = y
        self.y_size = y.shape[-2:]

        self.x = x
        self.y = y
        self.w = pixel_weighting

        self.filter_regs = TensorList(filter_regs)
        if sample_weights.shape[0] != 1:
            self.sample_weights = sample_weights
        else:
            # One shared weight given: replicate it for every sample in x.
            self.sample_weights = x.new_zeros(x.shape[0]).fill_(
                sample_weights[0])

        self.diag_M = TensorList(precond)
        self.pixel_weighting = pixel_weighting

        self.net = net
        self.compute_norm = compute_norm
Пример #2
0
def empirical_estimate(points, num_neighbors):
    """Estimate a per-point weight from the local spread of each neighbourhood.

    All point sets are concatenated, and the ``num_neighbors`` nearest
    neighbours of every point (restricted to the same set via ``batch``) are
    gathered. For each point, the neighbourhood is centred and the eigenvalues
    of its 3x3 scatter matrix are computed. The point's value is the square
    root of the product of the two largest eigenvalues. The returned weight is
    the median of that value over the point's neighbours.

    Args:
        points: nested (TensorListList-like) collection of point sets.
            ``points.permute(1, 0)`` must give (N_i, 3) rows so that the
            neighbourhoods can be reshaped to (N, num_neighbors, 3).
        num_neighbors: neighbourhood size, used both for the scatter matrix
            and for the median.

    Returns:
        TensorListList with the same nesting as ``points``. Each entry holds
        the (N_i, 1) weights of the corresponding point set.
    """
    ps, batch = points.permute(1, 0).cat_tensors()

    N = ps.shape[0]
    nn_idx = knn(ps.contiguous(),
                 ps.contiguous(),
                 batch_x=batch,
                 batch_y=batch,
                 k=num_neighbors)
    A = ps[nn_idx[1, :]].reshape(N, num_neighbors, 3)
    A = A - A.mean(dim=1, keepdim=True)
    Asqr = A.permute(0, 2, 1).bmm(A)
    # torch.symeig was removed in PyTorch 2.0. eigvalsh returns the same
    # eigenvalues in ascending order, so columns 1 and 2 are the two largest.
    sigma = torch.linalg.eigvalsh(Asqr.cpu())
    # Scatter matrices are PSD, but round-off can make near-zero eigenvalues
    # slightly negative; clamp so sqrt never produces NaN.
    w = (sigma[:, 2] * sigma[:, 1]).clamp(min=0).sqrt()
    # Move to the device of the neighbour indices before indexing: a CPU
    # tensor cannot be indexed with CUDA indices.
    w = w.to(ps.device)
    neighbors = nn_idx[1, :].reshape(N, num_neighbors)
    w, _ = torch.median(w[neighbors], dim=1, keepdim=True)

    # Split the concatenated weights back into the original nesting; sets
    # were numbered consecutively by cat_tensors in iteration order.
    weights = TensorListList()
    bi = 0
    for point_list in points:
        ww = TensorList()
        for _ in point_list:
            ww.append(w[batch == bi])
            bi += 1
        weights.append(ww)

    return weights
Пример #3
0
def GARPointList(t1_list, t2_list, R1_list, R2_list, V, c, alpha):
    """Robust per-pair point-displacement error between two sets of poses.

    For each batch entry and every view pair (i, j) with i < j, the points
    ``Vb[j]`` are mapped with the estimated relative pose from j to i
    (from ``R1``/``t1``) and with the ground-truth one (from ``R2``/``t2``).
    The Euclidean distance between the two results is summed over dim -2 and
    passed through ``lossfun(..., alpha=alpha, scale=c)``; its mean is
    recorded.

    Returns:
        TensorListList: per batch entry, one scalar error per view pair,
        ordered (0, 1), (0, 2), ..., (1, 2), ...
    """
    err = TensorListList()

    for t1, t2, R1, R2, Vb in zip(t1_list, t2_list, R1_list, R2_list, V):
        pair_errors = TensorList()
        num_views = len(t1)
        for i in range(num_views - 1):
            R1i_T = R1[i].permute(1, 0)
            R2i_T = R2[i].permute(1, 0)
            for j in range(i + 1, num_views):
                # Estimated relative pose from view j to view i.
                R_est = R1i_T @ R1[j]
                t_est = R1i_T @ t1[j] - R1i_T @ t1[i]
                # Ground-truth relative pose from view j to view i.
                R_ref = R2i_T @ R2[j]
                t_ref = R2i_T @ t2[j] - R2i_T @ t2[i]

                points = Vb[j]
                delta = R_est @ points + t_est - (R_ref @ points + t_ref)

                dist = (delta * delta).sum(dim=-2).sqrt()
                pair_errors.append(lossfun(dist, alpha=alpha, scale=c).mean())

        err.append(pair_errors)

    return err
Пример #4
0
    def initialize_mu(self, K, features):
        """Draw K random unit-norm vectors for every batch entry.

        For each entry ``TV`` of ``features``, a (D, K) float32 Gaussian
        matrix is drawn with NumPy's global RNG, where D = ``TV[0].shape[0]``.
        It is moved to ``TV[0]``'s device, each column is normalised to unit
        L2 norm, and the result is stored transposed as (K, D).

        Returns:
            TensorList with one (K, D) tensor per batch entry.
        """
        X = TensorList()
        for TV in features:
            ref = TV[0]
            gaussian = np.random.randn(ref.shape[0], K).astype(np.float32)
            cols = torch.from_numpy(gaussian).to(ref.device)
            unit_cols = cols / torch.norm(cols, dim=0, keepdim=True)
            X.append(unit_cols.permute(1, 0))
        return X
Пример #5
0
def L2sqrList(R1_list, R2_list):
    """Squared Frobenius distance between estimated and reference relative rotations.

    For every batch entry and every view pair (i, j) with i < j, the relative
    rotations ``R1[i]^T R1[j]`` and ``R2[i]^T R2[j]`` are compared. The sum of
    squared entries of their difference is stored as a 1-element tensor.

    Returns:
        TensorListList: per batch entry, one error per view pair in
        lexicographic (i, j) order.
    """
    err = TensorListList()
    for R1, R2 in zip(R1_list, R2_list):
        pair_errors = TensorList()
        num_views = len(R1)
        for i in range(num_views - 1):
            for j in range(i + 1, num_views):
                rel_est = R1[i].permute(1, 0) @ R1[j]
                rel_ref = R2[i].permute(1, 0) @ R2[j]
                delta = rel_est - rel_ref
                sq_norm = (delta * delta).sum(dim=-1).sum(dim=-1)
                pair_errors.append(sq_norm.unsqueeze(0))
        err.append(pair_errors)

    return err
Пример #6
0
 def A(self, x):
     """Return J^T J x using two reverse-mode autograd passes.

     ``self.dfdxt_g`` is J^T g, built with ``create_graph=True`` from the
     residuals ``self.f0``. Differentiating it with respect to ``self.g``
     against ``x`` gives J x. Differentiating ``self.f0`` with respect to
     ``self.x`` against J x then gives J^T (J x). Both passes keep the graph
     alive for further products.
     """
     jac_x = torch.autograd.grad(self.dfdxt_g, self.g, x, retain_graph=True)
     jtj_x = torch.autograd.grad(self.f0, self.x, jac_x, retain_graph=True)
     return TensorList(jtj_x)
Пример #7
0
def get_init_transformation_list(pcds, mean_init=True):
    """Build initial rotations and translations for nested point sets.

    Args:
        pcds: nested collection of point sets, one group per batch entry.
        mean_init: when True, each set's translation moves its mean (over the
            last dimension) onto the mean of the first set in its group.
            Otherwise all translations are (3, 1) zeros.

    Returns:
        ``(init_R, init_t)``. ``init_R`` is a TensorListList of 3x3
        identities mirroring ``pcds``; ``init_t`` holds the translations.
        Constructed tensors are placed on the device of ``pcds[0][0]``.
    """
    if not mean_init:
        init_t = TensorListList([
            TensorList([torch.zeros(3, 1).to(pcds[0][0].device)
                        for _ in group])
            for group in pcds
        ])
    else:
        m_pcds = pcds.mean(dim=-1, keepdim=True)
        m_target = TensorList([group_means[0] for group_means in m_pcds])
        init_t = -(m_pcds - m_target)

    init_R = TensorListList([
        TensorList([torch.eye(3, 3).to(pcds[0][0].device) for _ in group])
        for group in pcds
    ])
    return init_R, init_t
Пример #8
0
    def run_GN_iter(self, num_cg_iter):
        """Run one Gauss-Newton step, solving the linearised system with CG.

        Args:
            num_cg_iter: number of conjugate-gradient iterations for this step.

        Side effects: sets ``self.f0``, ``self.g``, ``self.dfdxt_g`` and
        ``self.b``, adds ``self.step_alpha * delta_x`` to ``self.x``, then
        grows ``self.step_alpha`` by a factor of 1.2, capped at 1.0.
        """

        self.x.requires_grad_(True)

        # Residuals at the current estimate. g is a detached copy turned into
        # a leaf that requires grad, so J^T g below stays differentiable
        # with respect to g.
        self.f0 = self.problem(self.x)
        self.g = self.f0.detach()
        self.g.requires_grad_(True)
        # create_graph=True keeps the graph of J^T g so it can be
        # differentiated again when forming J^T J products.
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g,
                                create_graph=True))  # df/dx^t @ f0
        # Right-hand side of the normal equations: -J^T f0.
        self.b = -self.dfdxt_g.detach()

        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Detach before updating so the step is not recorded by autograd.
        self.x.detach_()
        self.x += self.step_alpha * delta_x
        self.step_alpha = min(self.step_alpha * 1.2, 1.0)
Пример #9
0
    def cluster_features(self, features, num_clusters):
        """Replace features by one-hot k-means cluster labels.

        For each batch entry, all its feature tensors are concatenated along
        dim 0 and clustered on the CPU with ``KMeans(random_state=0)``. The
        one-hot labels are moved to ``self.params.device`` and split back into
        per-tensor chunks (``fi.shape[0]`` rows each).

        Returns:
            TensorListList: per batch entry, the per-tensor one-hot label
            matrices, transposed by ``permute(1, 0)`` and cast to float.
        """
        feature_labels_LL = TensorListList()
        for f in features:
            stacked = torch.cat([fi for fi in f]).to("cpu").numpy()
            kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(stacked)
            cluster_ids = torch.from_numpy(kmeans.labels_).long()
            onehot = torch.nn.functional.one_hot(
                cluster_ids, num_clusters).to(self.params.device)

            # Slice the concatenated rows back into one chunk per tensor.
            feature_labels_L = TensorList()
            start = 0
            for fi in f:
                stop = start + fi.shape[0]
                feature_labels_L.append(onehot[start:stop])
                start = stop

            feature_labels_LL.append(feature_labels_L.permute(1, 0).float())

        return feature_labels_LL
Пример #10
0
def L2sqrTransList(t1_list, t2_list, R1_list, R2_list):
    """Squared error between estimated and reference relative translations.

    For every batch entry and every view pair (i, j) with i < j, the relative
    translation from view j to view i is ``R_i^T t_j - R_i^T t_i``. It is
    computed once from (``t1``, ``R1``) and once from (``t2``, ``R2``); the
    squared difference summed over dim -2 is recorded.

    Returns:
        TensorListList: per batch entry, one error per view pair.
    """
    err = TensorListList()
    for t1, t2, R1, R2 in zip(t1_list, t2_list, R1_list, R2_list):
        pair_errors = TensorList()
        num_views = len(t1)
        for i in range(num_views - 1):
            for j in range(i + 1, num_views):
                R1i_T = R1[i].permute(1, 0)
                R2i_T = R2[i].permute(1, 0)
                t_est = R1i_T @ t1[j] - R1i_T @ t1[i]
                t_ref = R2i_T @ t2[j] - R2i_T @ t2[i]
                delta = t_est - t_ref
                pair_errors.append((delta * delta).sum(dim=-2))
        err.append(pair_errors)
    return err
Пример #11
0
def extract_correspondences_gpu(coords, Rs, ts, device=None):
    """Nearest-neighbour correspondences between every pair of transformed sets.

    Every set is mapped with its pose (``Rs @ coords + ts``) and, if
    ``device`` is given, moved there. For each pair (i, j) with i < j,
    ``knn(..., 1)`` produces an index tensor. Row 0 indexes points of set i
    and row 1 indexes points of set j, as used for the distance computation.

    Returns:
        dict with ``indices`` (TensorList of the knn index tensors) and
        ``distances`` (TensorList of Euclidean distances between matched
        points), both in pair order.
    """
    coords = Rs @ coords + ts
    num_sets = len(coords)
    if device is not None:
        coords = coords.to(device)

    ind_nns = []
    dists = []
    for i in range(num_sets - 1):
        pts_i = coords[i].permute(1, 0)
        for j in range(i + 1, num_sets):
            pts_j = coords[j].permute(1, 0)
            pair_idx = knn(pts_j, pts_i, 1)
            offsets = pts_i[pair_idx[0, :]] - pts_j[pair_idx[1, :]]
            dists.append((offsets * offsets).sum(dim=1).sqrt())
            ind_nns.append(pair_idx)

    return dict(indices=TensorList(ind_nns), distances=TensorList(dists))
Пример #12
0
def extract_correspondences(coords, Rs, ts):
    """CPU KD-tree nearest-neighbour correspondences between transformed sets.

    Every set is mapped with its pose (``Rs @ coords + ts``). For each pair
    (i, j) with i < j, every point of set j is matched to its nearest point
    in set i.

    Returns:
        dict with ``indices`` (TensorList of (2, N_j) tensors: row 0 the
        matched index in set i, row 1 the running index ``0..N_j-1`` in set
        j) and ``distances`` (TensorList of the match distances), moved to
        the device of the transformed coordinates.
    """
    coords = Rs @ coords + ts

    M = len(coords)
    # Convert every set to an (N, 3) array once rather than once per pair.
    # detach() keeps .numpy() from failing when the poses carry gradients.
    points = [c.detach().cpu().numpy().T for c in coords]

    ind_nns = []
    dists = []
    for ind1 in range(M - 1):
        # The tree over set ind1 serves every later set; build it once.
        tree = KDTree(points[ind1])
        for ind2 in range(ind1 + 1, M):
            query = points[ind2]
            d, inds_nn1 = tree.query(query, k=1)
            inds_nn2 = torch.arange(query.shape[0])
            inds_nn1 = torch.from_numpy(inds_nn1)

            ind_nns.append(torch.stack([inds_nn1, inds_nn2], dim=0))
            dists.append(torch.from_numpy(d))

    device = coords[0][0].device
    return dict(indices=TensorList(ind_nns).to(device),
                distances=TensorList(dists).to(device))
Пример #13
0
    def forward(self, x):
        """Preprocess every point set, cluster its features and run registration.

        Each point set in ``x['coordinates']`` is downsampled and described by
        ``self.preprocess_point_cloud``. Features are replaced by one-hot
        cluster labels (``self.cluster_features``), and the result is passed
        to ``self.registration``. The total wall time is printed and stored.

        Returns:
            The registration output dict, extended with ``time``,
            ``features`` (raw per-set features), ``indices`` and
            ``coordinates_ds``.
        """
        coords = x['coordinates'].clone()
        t_tot = time.time()

        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        for coords_b in coords:
            features_L = TensorList()
            coords_ds_L = TensorList()
            for coords_s in coords_b:
                pcd_down, feats = self.preprocess_point_cloud(
                    coords_s, self.params.voxel_size)
                features_L.append(feats)
                coords_ds_L.append(pcd_down)

            features_LL.append(features_L)
            # This path produces no attention maps or sample indices, so
            # empty placeholders are stored per batch entry.
            att_LL.append(TensorList())
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append([])

        reg_input = dict()
        reg_input['features'] = self.cluster_features(
            features_LL,
            self.params.feature_distr_parameters.num_feature_clusters)
        reg_input['att'] = att_LL
        reg_input['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        reg_input['indices'] = inds_LL

        out = self.registration(reg_input)
        tot_time = time.time() - t_tot
        print("tot time: %.1f ms" % (tot_time * 1000))

        out["time"] = tot_time
        out["features"] = features_LL
        out["indices"] = inds_LL
        out["coordinates_ds"] = coords_ds_LL
        return out
Пример #14
0
def estimate_overlap(ps1, ps2, R1, R2, t1, t2, voxel_size, device):
    """Estimate the overlap ratio between two posed point sets.

    Both sets are voxel-downsampled. Nearest-neighbour correspondences are
    then computed after applying the poses (R1, t1) and (R2, t2). The
    overlap rate is the fraction of correspondence distances below
    ``voxel_size``.

    Returns:
        list of floats, one rate per distance tensor returned by
        ``extract_correspondences_gpu``.
    """
    def voxel_downsample(ps):
        # Keep one point per occupied voxel (columns are points).
        # NOTE(review): assumes sparse_quantize(..., return_index=True)
        # returns only the index tensor in the installed MinkowskiEngine
        # version; confirm.
        quantized = torch.floor(ps.permute(1, 0) / voxel_size)
        keep = ME.utils.sparse_quantize(quantized, return_index=True)
        return ps[:, keep]

    ps1_v = voxel_downsample(ps1)
    ps2_v = voxel_downsample(ps2)

    corrs = extract_correspondences_gpu(TensorList([ps1_v, ps2_v]),
                                        TensorList([R1, R2]),
                                        TensorList([t1, t2]), device)

    thresh = voxel_size
    return [float((di < thresh).sum()) / di.shape[0]
            for di in corrs["distances"]]
Пример #15
0
def get_init_transformation_list_dgr(pcds, dgr_init_model):
    """Initial poses for point-set pairs from a DGR registration model.

    Each element of ``pcds`` must hold exactly two sets. The first is
    registered onto the second with ``dgr_init_model.register_point_sets``.
    The second keeps the identity rotation and a zero translation.

    Returns:
        ``(Rs, ts)``: TensorListLists of ``[R, I]`` / ``[t, 0]`` when ``pcds``
        is a TensorListList, otherwise plain lists of stacked tensors.
        ``t`` is unsqueezed at dim 1 (presumably from (3,) to (3, 1)).
    """
    nested = isinstance(pcds, TensorListList)
    if nested:
        Rs, ts = TensorListList(), TensorListList()
    else:
        Rs, ts = [], []

    target_R = torch.eye(3, 3).to(pcds[0][0].device)
    target_t = torch.zeros(3, 1).to(pcds[0][0].device)
    for pc in pcds:
        assert len(pc) == 2
        R, t = dgr_init_model.register_point_sets(pc[0].permute(1, 0),
                                                   pc[1].permute(1, 0))
        t_col = t.unsqueeze(dim=1)

        if nested:
            Rs.append(TensorList([R, target_R]))
            ts.append(TensorList([t_col, target_t]))
        else:
            Rs.append(torch.stack([R, target_R]))
            ts.append(torch.stack([t_col, target_t]))

    return Rs, ts
Пример #16
0
def _append_sample_fields(merged, tensorlist_keys, sample):
    """Append every value of ``sample`` to the per-key container in ``merged``.

    A key's container is created on its first occurrence. It is a TensorList
    when that first value is a TensorList (the key is then recorded in
    ``tensorlist_keys``), otherwise a plain list.
    """
    for k, v in sample.items():
        if k not in merged:
            if isinstance(v, TensorList):
                merged[k] = TensorList()
                tensorlist_keys.append(k)
            else:
                merged[k] = []
        merged[k].append(v)


def collate_tensorlist(batch):
    """Collate ``(data, info)`` dict pairs into two batched dicts.

    Values are gathered per key across the batch. Keys whose first value is
    a TensorList are finally wrapped in a TensorListList; all other keys
    stay plain lists.

    Args:
        batch: iterable of ``(data_dict, info_dict)`` pairs.

    Returns:
        ``(out, info)``: the collated data and info dicts.
    """
    out = dict()
    info = dict()
    tensorlistkeysout = []
    tensorlistkeysinfo = []
    for cb, ib in batch:
        _append_sample_fields(out, tensorlistkeysout, cb)
        _append_sample_fields(info, tensorlistkeysinfo, ib)

    for k in tensorlistkeysout:
        out[k] = TensorListList(out[k])

    for k in tensorlistkeysinfo:
        info[k] = TensorListList(info[k])

    return out, info
Пример #17
0
    def initialize(self, K, features):
        """Initialise ``self.distr`` with one K-row distribution per batch entry.

        ``self.params.init_method`` selects the scheme:
          * ``"dirichlet"`` (default): each of the K rows is sampled from a
            Dirichlet distribution. Its concentration is the entry's
            normalised total feature mass per cluster.
          * ``"uniform"`` or any other value: every row is uniform over
            ``self.params.num_feature_clusters`` clusters.

        The list is wrapped as ``TensorListList(..., repeat=self.repeat_list)``.
        """
        init_method = self.params.get("init_method", "dirichlet")
        if init_method == "dirichlet":
            init_distr_list = TensorList()
            relative_distr = features.sum(dim=1).sum_list()
            relative_distr = relative_distr / relative_distr.sum()
            for r in relative_distr:
                # Named to avoid shadowing the builtin ``dir``.
                dirichlet = torch.distributions.dirichlet.Dirichlet(r)
                init_distr_list.append(dirichlet.sample((K, )))
        else:
            # "uniform" and unrecognised methods share the uniform init.
            # NOTE(review): unknown init_method values fall back silently, and
            # these tensors are created on the CPU, whereas Dirichlet samples
            # follow the device of ``features``; confirm both are intended.
            num_clusters = self.params.num_feature_clusters
            init_distr_list = [
                torch.ones(K, num_clusters).float() / num_clusters
                for _ in range(len(features))
            ]

        self.distr = TensorListList(init_distr_list, repeat=self.repeat_list)
Пример #18
0
    def __call__(self, parameters: TensorList):
        """Return the residual vector of the loss.

        The first entry is the pixel-weighted difference between the network
        output on ``self.x`` (bilinearly resized to ``self.y_size``) and the
        labels ``self.y``. It is followed by the regularisation residuals
        ``self.filter_regs * parameters``. The data term uses the weights held
        by ``self.net``; ``parameters`` only enters the regularisation terms.
        """
        scores = F.interpolate(self.net(self.x), self.y_size,
                               mode='bilinear', align_corners=False)
        data_residual = self.w * (scores - self.y)
        reg_residuals = self.filter_regs * parameters
        return TensorList([data_residual, *reg_residuals])
Пример #19
0
    def init(self, x, y):
        """Fit projection and filter on the first frame, then prepare online updates.

        :param x: Tensor of data augmented features from the first frame, shape (K, Cf, Hf, Wf),
                  where K is the number of augmented feature maps.
        :param y: Object mask tensor, shape (K, 1, Him, Wim)

        Stage 1 optimises projection and filter jointly on a memory holding
        the K augmented samples. Stage 2 re-projects ``x``, fills a memory of
        size ``self.memory_size`` and runs the filter-only problem. That
        problem's memory and optimizer are kept for later updates.
        """
        pw = self.compute_pixel_weights(y)

        def fit(memory, parameters, filter_regs, precond, net, num_iters):
            # Build the loss over the memory contents and run Gauss-Newton/CG
            # with the network in train mode, restoring eval mode afterwards.
            problem = DiscriminatorLoss(x=memory.samples,
                                        y=memory.labels,
                                        filter_regs=filter_regs,
                                        precond=precond,
                                        sample_weights=memory.weights,
                                        net=net,
                                        pixel_weighting=memory.pixel_weights)
            optimizer = GaussNewtonCG(
                problem,
                parameters,
                fletcher_reeves=False,
                standard_alpha=True,
                direction_forget_factor=self.direction_forget_factor)
            problem.net.train()
            optimizer.run(num_iters)
            problem.net.eval()
            return optimizer

        # Stage 1: joint projection + filter optimisation.
        init_memory = Memory(y.shape[0], x.shape[-3:], y.shape[-3:],
                             self.device, self.learning_rate)
        init_memory.initialize(x, y, pw)
        fit(init_memory,
            TensorList([self.project.weight, self.filter.weight]),
            self.filter_reg,
            self.precond,
            nn.Sequential(self.project, self.filter),
            self.init_iters)

        # Re-project samples with the new projection matrix.
        x = self.project(x)

        # Stage 2: long-term memory and the filter-only update problem.
        memory = Memory(self.memory_size, x.shape[-3:], y.shape[-3:],
                        self.device, self.learning_rate)
        memory.initialize(x, y, pw)
        optimizer = fit(memory,
                        TensorList([self.filter.weight]),
                        self.filter_reg[1:],
                        self.precond[1:],
                        self.filter,
                        self.update_iters)

        self.memory = memory
        self.update_optimizer = optimizer
Пример #20
0
    def register_point_sets(self, x):
        """Register the first point set of every pair onto the second with Open3D ICP.

        Args:
            x: dict whose ``"coordinates"`` entry holds, per batch entry, a
                pair ``(source, target)``. Each set is indexed as (3, N):
                means are taken over dim 1 and ``.T`` gives Open3D's (N, 3).

        Returns:
            dict with ``Rs`` / ``ts`` (per pair ``[R, I]`` and ``[t, 0]``, so
            the source is mapped by ``R @ p + t`` and the target stays put),
            ``Vs`` and the elapsed ``time`` in seconds.
        """
        point_clouds = x["coordinates"].clone()
        t_tot = time.time()

        if isinstance(point_clouds, TensorListList):
            Rs = TensorListList()
            ts = TensorListList()
        else:
            Rs = []
            ts = []

        # NOTE(review): these (and the ICP results below) are CPU tensors; for
        # TensorListList input on another device, ``Rs @ point_clouds`` below
        # would mix devices. Confirm inputs are always on the CPU.
        target_R = torch.eye(3, 3)
        target_t = torch.zeros(3, 1)
        for pcds in point_clouds:
            if self.params.get("mean_init", True):
                # Open3D applies ``init`` to the source, so the initial
                # translation must move the source centroid onto the target
                # centroid: m_target - m_source.
                m_target = pcds[1].mean(dim=1)
                m_source = pcds[0].mean(dim=1)
                init_t = (m_target - m_source).detach().cpu().numpy()
            else:
                init_t = [0, 0, 0]

            trans_init = np.asarray([[1., 0., 0., init_t[0]],
                                     [0., 1., 0., init_t[1]],
                                     [0., 0., 1., init_t[2]],
                                     [0.0, 0.0, 0.0, 1.0]])

            source = o3d.geometry.PointCloud()
            target = o3d.geometry.PointCloud()
            # detach() keeps .numpy() from failing on grad-carrying inputs.
            source.points = o3d.utility.Vector3dVector(
                pcds[0].detach().cpu().numpy().T)
            target.points = o3d.utility.Vector3dVector(
                pcds[1].detach().cpu().numpy().T)

            voxel_size = self.params.voxel_size
            max_correspondence_distance = self.params.threshold
            radius_normal = float(self.params.radius_normal)

            if self.params.metric == "p2pl":
                # Point-to-plane ICP needs normals on both clouds.
                source.estimate_normals(
                    search_param=o3d.geometry.KDTreeSearchParamHybrid(
                        radius=radius_normal, max_nn=30))
                target.estimate_normals(
                    search_param=o3d.geometry.KDTreeSearchParamHybrid(
                        radius=radius_normal, max_nn=30))

            if self.params.get("downsample", True):
                source = source.voxel_down_sample(voxel_size=voxel_size)
                target = target.voxel_down_sample(voxel_size=voxel_size)

            reg_p2p = o3d.registration.registration_icp(
                source,
                target,
                max_correspondence_distance=max_correspondence_distance,
                estimation_method=self.metric,
                init=trans_init)

            T = reg_p2p.transformation
            R = T[0:3, 0:3]
            t = T[0:3, 3:]

            if isinstance(point_clouds, TensorListList):
                Rs.append(TensorList([torch.from_numpy(R).float(), target_R]))
                ts.append(TensorList([torch.from_numpy(t).float(), target_t]))
            else:
                Rs.append(torch.stack([torch.from_numpy(R).float(), target_R]))
                ts.append(torch.stack([torch.from_numpy(t).float(), target_t]))

        tot_time = time.time() - t_tot
        print("ICP tot time: %.1f ms" % (tot_time * 1000))

        if isinstance(point_clouds, TensorListList):
            return dict(Rs=Rs, ts=ts, Vs=Rs @ point_clouds + ts, time=tot_time)
        else:
            # NOTE(review): unlike the TensorListList branch, Vs here is the
            # untransformed input; confirm callers expect that.
            return dict(Rs=torch.stack(Rs),
                        ts=torch.stack(ts),
                        Vs=point_clouds,
                        time=tot_time)
Пример #21
0
    def forward(self, x_in):
        """Extract sparse features per point set, run registration and time each stage.

        Args:
            x_in: dict with a nested ``'coordinates'`` collection (batch entry
                -> point sets, columns are points).

        Returns:
            The registration output dict, extended with ``time``,
            ``reg_time``, ``extract_time``, ``resample_time`` (seconds) and
            ``coordinates_ds``.
        """
        coords = x_in['coordinates'].clone()
        t_tot = time.time()
        sinput, inds_list = self.create_sparse_tensors(coords)
        resample_time = time.time() - t_tot
        print("resample time: %.1f ms" % ((resample_time) * 1000))
        if self.params.feature_distr_parameters.model != 'none':
            features_dict = self.feature_extractor(sinput)
            features = features_dict["features"]
            has_attention = "attention" in features_dict
            if has_attention:
                att = features_dict["attention"]

            # Measured from t_tot, so this also includes the resample time.
            extract_time = time.time() - t_tot
            print("extract time: %.1f ms" % ((extract_time) * 1000))
            batch_indices = list(features.coords_man.get_batch_indices())
        else:
            features_dict = None
            has_attention = False
            extract_time = 0

        time_preprocess = time.time()

        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        ind_cnt = 0
        for coords_b in coords:
            features_L = TensorList()
            att_L = TensorList()
            coords_ds_L = TensorList()
            inds_L = []
            for coords_s in coords_b:
                if features_dict is not None:
                    mask = features.C[:, 0] == batch_indices[ind_cnt]

                    # Hacky way of finding the batch channel: if column 0
                    # does not select the expected number of points, use the
                    # last column instead.
                    if mask.int().sum() != inds_list[ind_cnt].shape[0]:
                        mask = features.C[:, -1] == batch_indices[ind_cnt]

                    f = features.F[mask]
                    assert f.shape[0] == inds_list[ind_cnt].shape[0]
                    if has_attention:
                        a = att.F[mask]
                        assert a.shape[0] == inds_list[ind_cnt].shape[0]
                        att_L.append(a)

                    features_L.append(f.permute(1, 0))

                coords_ds_L.append(coords_s[:, inds_list[ind_cnt]])
                inds_L.append(inds_list[ind_cnt])

                ind_cnt = ind_cnt + 1

            features_LL.append(features_L)
            att_LL.append(att_L)
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append(inds_L)

        x = dict()
        x['features'] = features_LL
        x['att'] = att_LL
        x['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        x['indices'] = inds_LL
        x['coordinates'] = x_in['coordinates']

        print("preprocess time: %.1f ms" % ((time.time()-time_preprocess) * 1000))

        reg_time = time.time()
        out = self.registration(x)
        reg_time2 = time.time() - reg_time
        # Print the same value that is stored in out["reg_time"].
        print("reg time: %.1f ms" % (reg_time2 * 1000))
        tot_time = time.time() - t_tot
        print("tot time: %.1f ms" % (tot_time * 1000))

        out["time"] = tot_time
        out["reg_time"] = reg_time2
        out["extract_time"] = extract_time
        out["resample_time"] = resample_time
        out["coordinates_ds"] = coords_ds_LL

        return out
Пример #22
0
    def read_sample(self, specs):
        """Load, downsample and preprocess every view of one sample.

        Args:
            specs: sample specification. Must provide ``view_names`` and the
                sequence name, which is read both as ``specs["sequence"]``
                and as ``specs.sequence``.

        Returns:
            ``(out, info)``.
            ``out`` holds ``'coordinates'`` as a TensorList, plus
            ``'features'`` when ``parameters.with_features`` is set and
            ``'time_stamps'`` when any view provides them.
            ``info`` holds ``'R_gt'``, ``'t_gt'``, identity ``'R_init'``,
            zero ``'t_init'``, ``'sequence'``, ``'names'`` and, when
            ``parameters.with_correspondences`` is set,
            ``'correspondences'``.

        When ``parameters.augment == True``, each view's coordinates are
        mapped by the inverse of a random rigid transform (R, t), i.e.
        R^T (p - t). Its ground truth becomes (R_gt R, R_gt t + t_gt), which
        maps the augmented coordinates to the same frame as the original
        ground truth mapped the original coordinates.
        """
        coords = []
        R_gt = []
        t_gt = []
        R_init = []
        t_init = []
        time_stamps = []
        names = []

        if self.parameters.with_features:
            features = []

        for name in specs.view_names:
            names.append(specs["sequence"] + "/" + name)

            data = self.loader(name, specs)
            data = self.downsampler(data, self.parameters.with_features)
            data = self.preprocesser(data, self.parameters.with_features)

            if self.parameters.augment == True:
                ## Generate a random rigid transformation for augmentation.
                # NOTE: the four torch.rand draws happen in this fixed order;
                # reordering them changes the augmentation for a given seed.
                r = torch.rand(1)
                ax = torch.rand(3)
                direction = torch.rand(3)
                t_scale = torch.rand(1)
                R = get_rotation_matrix(self.parameters.ang_range, r=r, ax=ax)
                t = get_translation_vector(self.parameters.t_range * t_scale,
                                           direction=direction)
                # transform coordinates
                data['coordinates'] = R.t() @ data['coordinates'] - R.t(
                ) @ t  # transform with inverse augmented ground truth
                # Compose the ground truth with (R, t) so it undoes the
                # augmentation before applying the original pose.
                R_gt.append(data['R_gt'] @ R)
                t_gt.append(data['R_gt'] @ t + data['t_gt'])
            else:
                R_gt.append(data['R_gt'])
                t_gt.append(data['t_gt'])

            # Initial poses are always identity rotation / zero translation.
            R_init.append(torch.eye(3, 3))
            t_init.append(torch.zeros(3, 1))

            coords.append(data['coordinates'])
            if self.parameters.with_features:
                features.append(data['features'])

            if "time_stamps" in data.keys():
                time_stamps.append(data["time_stamps"])

        out = {'coordinates': TensorList(coords)}

        info = {
            'R_gt': TensorList(R_gt),
            't_gt': TensorList(t_gt),
            'R_init': TensorList(R_init),
            't_init': TensorList(t_init),
            'sequence': specs.sequence,
            'names': names
        }

        if self.parameters.with_features:
            out['features'] = TensorList(features)

        if self.parameters.with_correspondences:
            info['correspondences'] = self.extract_correspondences(out, info)

        if len(time_stamps) > 0:
            out['time_stamps'] = TensorList(time_stamps)

        return out, info
Пример #23
0
class GaussNewtonCG:
    """Gauss-Newton minimiser whose linearised systems are solved with CG.

    Each Gauss-Newton step evaluates the residuals ``f = problem(x)`` and
    approximately solves ``(J^T J) delta_x = -J^T f`` with preconditioned
    conjugate gradient. Products with ``J^T J`` are formed by two autograd
    passes (see ``A``), so J is never materialised.

    Args:
        problem: callable returning residuals for a TensorList of variables.
            It must also provide ``initialize()``, ``M1()`` (preconditioner)
            and ``ip_input()`` (inner product).
        variable: TensorList of variables to optimise.
        cg_eps: forwarded to ``run_CG`` as ``eps`` (not used there).
        fletcher_reeves: Fletcher-Reeves beta if True, else Polak-Ribiere.
        standard_alpha: CG step ``rho / <p, q>`` if True, else
            ``<p, r> / <p, q>``.
        direction_forget_factor: 0 resets the CG state at every GN step.
            Otherwise the previous search direction is reused and ``rho`` is
            divided by this factor.
        step_alpha: initial GN step size; multiplied by 1.2 after every step
            and capped at 1.0.
    """

    def __init__(self,
                 problem: MinimizationProblem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 step_alpha=1.0):

        self.fletcher_reeves = fletcher_reeves
        self.standard_alpha = standard_alpha
        self.direction_forget_factor = direction_forget_factor

        # CG state: search direction, <r, z> and previous residual.
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

        # Right-hand side of the normal equations, -J^T f.
        self.b = None

        self.problem = problem
        self.x = variable

        self.cg_eps = cg_eps
        # Autograd intermediates of the current GN step.
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Returned by run(); not updated by run_GN_iter or run_CG.
        self.residuals = torch.zeros(0)
        self.external_losses = []
        self.internal_losses = []
        self.gradient_mags = torch.zeros(0)

        self.step_alpha = step_alpha

    def clear_temp(self):
        """Drop the autograd intermediates of the last GN step."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimisation.

        Args:
            num_cg_iter: CG iterations per GN step. Either a list (one entry
                per GN step) or an int combined with ``num_gn_iter``.
            num_gn_iter: number of GN steps; required when ``num_cg_iter`` is
                an int.

        Returns:
            ``(external_losses, internal_losses, residuals)``. The same tuple
            is returned when zero GN steps are requested.

        Raises:
            ValueError: if ``num_cg_iter`` is an int and ``num_gn_iter`` is
                None.
        """

        self.problem.initialize()

        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        if len(num_cg_iter) == 0:
            # Nothing to optimise; keep the return type consistent.
            return self.external_losses, self.internal_losses, self.residuals

        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        self.x.detach_()
        self.clear_temp()

        return self.external_losses, self.internal_losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Run one Gauss-Newton step with ``num_cg_iter`` CG iterations."""

        self.x.requires_grad_(True)

        # g is a detached copy of the residuals made into a grad-requiring
        # leaf, so J^T g below is differentiable with respect to g.
        self.f0 = self.problem(self.x)
        self.g = self.f0.detach()
        self.g.requires_grad_(True)
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g,
                                create_graph=True))  # df/dx^t @ f0
        self.b = -self.dfdxt_g.detach()

        delta_x, _ = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Detach so the update itself is not recorded by autograd.
        self.x.detach_()
        self.x += self.step_alpha * delta_x
        self.step_alpha = min(self.step_alpha * 1.2, 1.0)

    def reset_state(self):
        """Forget the CG search direction and associated state."""
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

    def run_CG(self, num_iter, x=None, eps=0.0):
        """Main conjugate gradient method.

        Approximately solves ``A(x) = self.b``, starting from ``x`` (zero
        when None). ``eps`` is accepted but not used.

        Returns:
            ``(x, [])``: the solution estimate and an empty list.
        """

        # Apply forgetting factor
        if self.direction_forget_factor == 0:
            self.reset_state()
        elif self.p is not None:
            self.rho /= self.direction_forget_factor

        if x is None:
            r = self.b.clone()
        else:
            r = self.b - self.A(x)

        # Loop over iterations
        for ii in range(num_iter):

            z = self.problem.M1(r)  # Preconditioner

            rho1 = self.rho
            self.rho = self.ip(r, z)

            if self.p is None:
                self.p = z.clone()
            else:
                if self.fletcher_reeves:
                    beta = self.rho / rho1
                else:
                    rho2 = self.ip(self.r_prev, z)
                    beta = (self.rho - rho2) / rho1

                # Negative beta restarts along the preconditioned residual.
                beta = beta.clamp(0)
                self.p = z + self.p * beta

            q = self.A(self.p)
            pq = self.ip(self.p, q)

            if self.standard_alpha:
                alpha = self.rho / pq
            else:
                alpha = self.ip(self.p, r) / pq

            # Save old r for PR formula
            if not self.fletcher_reeves:
                self.r_prev = r.clone()

            # Form new iterate
            if x is None:
                x = self.p * alpha
            else:
                x += self.p * alpha

            # The residual is not needed after the last iteration.
            if ii < num_iter - 1:
                r -= q * alpha

        return x, []

    def A(self, x):
        """Return J^T J x via two autograd passes.

        Differentiating J^T g (``self.dfdxt_g``) with respect to g against
        ``x`` gives J x. Differentiating the residuals with respect to
        ``self.x`` against J x gives J^T J x.
        """
        dfdx_x = torch.autograd.grad(self.dfdxt_g,
                                     self.g,
                                     x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        """Inner product in the problem's input space."""
        return self.problem.ip_input(a, b)