예제 #1
0
def loss(y_pred, y_true, group, metric=None):
    """
    Evaluate the loss induced by a Riemannian metric.

    The computation is delegated to ``riemannian_metric.loss``. When no
    ``metric`` is passed, ``group.left_invariant_metric`` is used instead;
    ``group`` is not consulted otherwise.
    """
    chosen_metric = group.left_invariant_metric if metric is None else metric
    return riemannian_metric.loss(y_pred, y_true, chosen_metric)
예제 #2
0
def loss(y_pred, y_true, group, metric=None):
    """Compute the loss induced by a Riemannian metric.

    Parameters
    ----------
    y_pred
        Prediction, forwarded unchanged to ``riemannian_metric.loss``.
    y_true
        Target, forwarded unchanged to ``riemannian_metric.loss``.
    group
        Object exposing ``left_invariant_metric``. Only read when
        ``metric`` is None.
    metric
        Metric handed to ``riemannian_metric.loss``.
        Optional, defaults to ``group.left_invariant_metric`` if None.

    Returns
    -------
    loss
        The value returned by ``riemannian_metric.loss``.
    """
    # An explicitly supplied metric always takes precedence over the group.
    if metric is not None:
        return riemannian_metric.loss(y_pred, y_true, metric)
    fallback_metric = group.left_invariant_metric
    return riemannian_metric.loss(y_pred, y_true, fallback_metric)
예제 #3
0
def loss(y_pred, y_true, group, metric=None):
    """Compute the loss induced by a Riemannian metric.

    Parameters
    ----------
    y_pred : array-like, shape=[n_samples, {dimension, [n, n]}]
        Prediction.
    y_true : array-like, shape=[n_samples, {dimension, [n, n]}]
        Target; its shape has to match ``y_pred``.
    group : LieGroup
        Only read (for its ``left_invariant_metric``) when ``metric`` is
        None.
    metric : RiemannianMetric, optional
        Metric used for the loss.
        Default: the left invariant metric of the Lie group.

    Returns
    -------
    loss : array-like
        Squared (geodesic) distance between ``y_pred`` and ``y_true``, as
        computed by ``riemannian_metric.loss``.
    """
    active_metric = metric if metric is not None else group.left_invariant_metric
    return riemannian_metric.loss(y_pred, y_true, active_metric)
예제 #4
0
def loss(y_pred, y_true, group, metric=None):
    """Compute the loss induced by a Riemannian metric.

    Parameters
    ----------
    y_pred : array-like, shape=[..., {dim, [n, n]}]
        Prediction.
    y_true : array-like, shape=[..., {dim, [n, n]}]
        Ground-truth.
        Shape has to match y_pred.
    group : LieGroup
        Lie group; only read (for its ``left_canonical_metric``) when
        ``metric`` is None.
    metric : RiemannianMetric
        Riemannian metric.
        Optional, defaults to ``group.left_canonical_metric`` if None.

    Returns
    -------
    loss : array-like
        Squared (geodesic) distance between y_pred and y_true, as computed
        by ``riemannian_metric.loss``.
    """
    selected_metric = group.left_canonical_metric if metric is None else metric
    return riemannian_metric.loss(y_pred, y_true, selected_metric)
def _unit_direction_loss(reference, estimate):
    """Return the batch mean of the squared distance between unit directions.

    Both inputs are divided by their norm along the last axis (the
    module-level ``EPS`` keeps the division finite for a zero vector); the
    squared differences are summed over the last axis and averaged over the
    remaining ones.
    """
    reference_dir = reference / (
        torch.norm(reference, dim=-1, keepdim=True) + EPS)
    estimate_dir = estimate / (
        torch.norm(estimate, dim=-1, keepdim=True) + EPS)
    return torch.mean(torch.sum((reference_dir - estimate_dir)**2, dim=-1))


def train():
    """Train ``RegiNet`` so that its pose gradient follows the geodesic gradient.

    Each iteration

    1. draws a ground-truth pose and a starting pose with
       ``init_rtvec_train`` and renders the target projection with
       ``ProST_init`` under ``torch.no_grad()``, min-max normalised per
       sample;
    2. runs ``RegiNet`` and takes the MSE between the two encodings it
       returns (``l2_loss``);
    3. differentiates ``l2_loss`` with respect to ``rtvec`` while keeping
       the graph, so the resulting gradient is itself differentiable;
    4. compares the direction of that gradient with ``riem.grad`` for the
       rotation part (``[:, :3]``) and the translation part (``[:, 3:]``)
       and back-propagates the sum of the two direction losses.

    Only the direction loss is optimised; ``l2_loss`` is logged as
    ``tLoss`` but never back-propagated on its own.

    All configuration (``CT_PATH``, ``SEG_PATH``, ``BATCH_SIZE``,
    ``VOX_SPAC``, ``zFlip``, ``device``, ``RESUME_EPOCH``, ``RESUME_MODEL``,
    ``ITER_NUM``, ``METRIC``, ``EPS``, ``clipping_value``,
    ``SAVE_MODEL_EVERY_EPOCH``, ``SAVE_PATH``) is read from module-level
    names. A checkpoint is written to
    ``SAVE_PATH/checkpoint/vali_model<epoch>.pt`` every
    ``SAVE_MODEL_EVERY_EPOCH`` epochs.

    Returns
    -------
    None
    """
    import os  # local import: only needed to create the checkpoint directory

    criterion_mse = nn.MSELoss()

    (param, det_size, _3D_vol, CT_vol, ray_proj_mov, corner_pt,
     _norm_factor) = input_param(CT_PATH, SEG_PATH, BATCH_SIZE, VOX_SPAC,
                                 zFlip)

    initmodel = ProST_init(param).to(device)
    model = RegiNet(param, det_size).to(device)

    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    scheduler = CyclicLR(optimizer,
                         base_lr=1e-6,
                         max_lr=1e-4,
                         step_size_up=100)

    if RESUME_EPOCH >= 0:
        print('Resuming model from epoch', RESUME_EPOCH)
        # map_location lets a checkpoint saved on another device be resumed
        # on the device used for this run.
        checkpoint = torch.load(RESUME_MODEL, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # NOTE: checkpoints carry no scheduler state, so the cyclic schedule
        # restarts from its beginning after a resume.

        start_epoch = RESUME_EPOCH + 1
        step_cnt = RESUME_EPOCH * ITER_NUM
    else:
        start_epoch = 0
        step_cnt = 0

    print('module parameters:', count_parameters(model))

    checkpoint_dir = os.path.join(SAVE_PATH, 'checkpoint')

    for epoch in range(start_epoch, 20000):
        # The printed statistics are per-epoch running means, so every
        # accumulator (including total_loss_list) starts empty each epoch.
        riem_grad_loss_list = []
        riem_grad_rot_loss_list = []
        riem_grad_trans_loss_list = []
        riem_dist_list = []
        riem_dist_mean_list = []
        mse_loss_list = []
        vecgrad_diff_list = []
        total_loss_list = []

        # Training iterations for this epoch
        model.train()
        for batch_idx in tqdm(range(ITER_NUM)):
            step_cnt = step_cnt + 1
            # Get target projection
            transform_mat3x4_gt, rtvec, rtvec_gt = init_rtvec_train(
                BATCH_SIZE, device)

            with torch.no_grad():
                target = initmodel(CT_vol, ray_proj_mov, transform_mat3x4_gt,
                                   corner_pt)
                # Min-max normalise every sample of the batch independently.
                flat_target = target.reshape(BATCH_SIZE, -1)
                min_tar, _ = torch.min(flat_target, dim=-1, keepdim=True)
                max_tar, _ = torch.max(flat_target, dim=-1, keepdim=True)
                target = (flat_target - min_tar) / (max_tar - min_tar)
                target = target.reshape(BATCH_SIZE, 1, det_size, det_size)

            # Do Projection and get two encodings
            encode_mov, encode_tar, proj_mov = model(_3D_vol, target, rtvec,
                                                     corner_pt)

            optimizer.zero_grad()
            # Calculate Net l2 Loss, L_N
            l2_loss = criterion_mse(encode_mov, encode_tar)

            # Find geodesic distance (computed on CPU copies, logging only)
            riem_dist = np.sqrt(
                riem.loss(rtvec.detach().cpu(),
                          rtvec_gt.detach().cpu(), METRIC))

            # Gradient of the network loss w.r.t. the pose. create_graph=True
            # keeps it differentiable so the direction loss below can be
            # back-propagated into the model parameters.
            rtvec_grad = torch.autograd.grad(
                l2_loss,
                rtvec,
                grad_outputs=torch.ones_like(l2_loss),
                only_inputs=True,
                create_graph=True,
                retain_graph=True)[0]
            # Find geodesic gradient; it is a fixed target (no grad).
            riem_grad = riem.grad(rtvec.detach().cpu(),
                                  rtvec_gt.detach().cpu(), METRIC)
            riem_grad = torch.tensor(riem_grad,
                                     dtype=torch.float,
                                     requires_grad=False,
                                     device=device)

            # Translation loss uses components 3 onward, rotation loss the
            # first three.
            riem_grad_trans_loss = _unit_direction_loss(
                riem_grad[:, 3:], rtvec_grad[:, 3:])
            riem_grad_rot_loss = _unit_direction_loss(
                riem_grad[:, :3], rtvec_grad[:, :3])

            riem_grad_loss = riem_grad_trans_loss + riem_grad_rot_loss

            riem_grad_loss.backward()

            # Clip training gradient magnitude (in-place variant; the
            # un-suffixed clip_grad_norm is deprecated).
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)

            # Learning rate applied by this update; read before the
            # scheduler advances to the next value.
            cur_lr = float(optimizer.param_groups[0]['lr'])
            optimizer.step()
            # The scheduler must step after the optimizer, otherwise the
            # first value of the schedule is skipped.
            scheduler.step()

            total_loss = l2_loss

            mse_loss_list.append(torch.mean(l2_loss).item())
            riem_grad_loss_list.append(riem_grad_loss.item())
            riem_grad_rot_loss_list.append(riem_grad_rot_loss.item())
            riem_grad_trans_loss_list.append(riem_grad_trans_loss.item())
            riem_dist_list.append(riem_dist)
            riem_dist_mean_list.append(np.mean(riem_dist))
            total_loss_list.append(total_loss.item())
            vecgrad_diff = (rtvec_grad - riem_grad).detach().cpu().numpy()
            vecgrad_diff_list.append(vecgrad_diff)

            torch.cuda.empty_cache()

            # LR is printed in scientific notation: the schedule spans
            # 1e-6..1e-4, which a 4-decimal fixed format cannot resolve.
            print('Train epoch: {} Iter: {} tLoss: {:.4f}, gLoss: {:.4f}/{:.2f}, gLoss_rot: {:.4f}/{:.2f}, gLoss_trans: {:.4f}/{:.2f}, LR: {:.2e}'.format(
                epoch, batch_idx, np.mean(total_loss_list),
                np.mean(riem_grad_loss_list), np.std(riem_grad_loss_list),
                np.mean(riem_grad_rot_loss_list), np.std(riem_grad_rot_loss_list),
                np.mean(riem_grad_trans_loss_list), np.std(riem_grad_trans_loss_list),
                cur_lr))

        if epoch % SAVE_MODEL_EVERY_EPOCH == 0:
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            # Make sure the target directory exists before the first save.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                state,
                SAVE_PATH + '/checkpoint/vali_model' + str(epoch) + '.pt')