示例#1
0
def crop_image(img, center, scale, res, base=384):
    """Crop a region of side ``base * scale`` around ``center`` and resample it to ``res``.

    Parameters
    ----------
    img : image
        Image exposing ``warp_to_shape`` / ``rescale`` / ``resize``
        (menpo-style image assumed — TODO confirm).
    center : sequence of 2 floats
        Centre of the crop in ``img`` coordinates.
    scale : float
        Multiplier on ``base``; the crop side in the original image is
        ``base * scale``.
    res : sequence of 2 ints
        Output resolution.
    base : int, optional
        Reference crop side at ``scale == 1``.

    Returns
    -------
    new_img : image
        The crop rescaled by ``1 / c_scale`` and resized to ``res``.
    trans : transform
        Transform returned by ``warp_to_shape`` for the intermediate crop.
    c_scale : float
        ``min(crop shape) / mean(res)``.
    """
    h = base * scale

    # Forward map (original -> output crop): scale by res / h, then translate
    # so that ``center`` lands at res / 2. Its pseudoinverse maps crop
    # coordinates back into the original image.
    t = Translation(
        [
            res[0] * (-center[0] / h + .5),
            res[1] * (-center[1] / h + .5)
        ]).compose_after(
        Scale((res[0] / h, res[1] / h))).pseudoinverse()

    # Corners of the crop in original-image coordinates, rounded outwards to
    # whole pixels. Round first and cast afterwards: casting before ``ceil``
    # truncates toward zero and makes the ceil a no-op. ``np.int`` was
    # removed in NumPy 1.24; the builtin ``int`` is the equivalent dtype.
    # Upper left point of original image
    ul = np.floor(t.apply([0, 0])).astype(int)
    # Bottom right point of original image
    br = np.ceil(t.apply(res)).astype(int)

    # crop and rescale
    cimg, trans = img.warp_to_shape(br - ul,
                                    Translation(-(br - ul) / 2 +
                                                (br + ul) / 2),
                                    return_transform=True)

    c_scale = np.min(cimg.shape) / np.mean(res)
    new_img = cimg.rescale(1 / c_scale).resize(res)

    return new_img, trans, c_scale
示例#2
0
def test_align_2d_translation():
    # A known 2D translation applied to a small point cloud must be
    # recovered exactly by AlignmentTranslation.
    offset = np.array([1, 2])
    true_transform = Translation(offset)
    src = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    dst = true_transform.apply(src)
    recovered = AlignmentTranslation(src, dst)
    assert_allclose(true_transform.h_matrix, recovered.h_matrix)
示例#3
0
def test_align_2d_translation_set_h_matrix_raises_notimplemented_error():
    """Setting the h_matrix of an AlignmentTranslation must raise NotImplementedError.

    The expected exception is checked explicitly. Without that check, an
    unexpected success would pass silently and the expected exception would
    be reported as a test error.
    """
    t_vec = np.array([1, 2])
    translation = Translation(t_vec)
    source = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    # estimate the transform from source to source..
    estimate = AlignmentTranslation(source, source)
    # ...then attempting to overwrite its homogeneous matrix must fail.
    try:
        estimate.set_h_matrix(translation.h_matrix)
    except NotImplementedError:
        pass
    else:
        raise AssertionError(
            'set_h_matrix should raise NotImplementedError '
            'for AlignmentTranslation')
示例#4
0
def test_align_2d_translation_set_h_matrix_raises_notimplemented_error():
    """Setting the h_matrix of an AlignmentTranslation must raise NotImplementedError.

    The expected exception is checked explicitly. Without that check, an
    unexpected success would pass silently and the expected exception would
    be reported as a test error.
    """
    t_vec = np.array([1, 2])
    translation = Translation(t_vec)
    source = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    # estimate the transform from source to source..
    estimate = AlignmentTranslation(source, source)
    # ...then attempting to overwrite its homogeneous matrix must fail.
    try:
        estimate.set_h_matrix(translation.h_matrix)
    except NotImplementedError:
        pass
    else:
        raise AssertionError(
            'set_h_matrix should raise NotImplementedError '
            'for AlignmentTranslation')
示例#5
0
def test_align_2d_translation():
    # Aligning a point set with its translated copy must yield exactly the
    # homogeneous matrix of the translation that produced the copy.
    shift_vector = np.array([1, 2])
    shift = Translation(shift_vector)
    cloud = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    shifted_cloud = shift.apply(cloud)
    fitted = AlignmentTranslation(cloud, shifted_cloud)
    assert_allclose(shift.h_matrix, fitted.h_matrix)
示例#6
0
def test_align_2d_translation_from_vector_inplace():
    # Start from an identity alignment (source -> source). Setting its
    # parameters in place to the translation vector must move the estimate's
    # target onto the translated source.
    offset = np.array([1, 2])
    shift = Translation(offset)
    pts = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    expected_target = shift.apply(pts)
    aligner = AlignmentTranslation(pts, pts)
    aligner.from_vector_inplace(offset)
    assert_allclose(expected_target.points, aligner.target.points)
示例#7
0
def test_align_2d_translation_from_vector_inplace():
    # Same check as the public API, but through the private in-place setter:
    # an identity alignment updated with the translation vector must carry
    # the translated points as its target.
    offset = np.array([1, 2])
    shift = Translation(offset)
    pts = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    expected_target = shift.apply(pts)
    aligner = AlignmentTranslation(pts, pts)
    aligner._from_vector_inplace(offset)
    assert_allclose(expected_target.points, aligner.target.points)
def mask_pc(pc):
    """Return the integer grid points lying inside the polygon described by ``pc``.

    The polygon is first centred on the origin. An axis-aligned grid with
    unit spacing covering its bounds is tested for containment. The points
    inside are moved back to the original position and cast to ``int``.
    """
    to_origin = Translation(-pc.centre())
    centred = to_origin.apply(pc)
    lower, upper = centred.bounds()
    y_min, x_min = lower
    y_max, x_max = upper

    rows = np.arange(np.floor(y_min), np.ceil(y_max))
    cols = np.arange(np.floor(x_min), np.ceil(x_max))
    grid_a, grid_b = np.meshgrid(rows, cols)
    candidates = np.column_stack((grid_a.ravel(), grid_b.ravel()))

    inside = matpath(centred.points).contains_points(candidates)
    restored = to_origin.pseudoinverse().apply(candidates[inside])
    return PointCloud(restored.astype(int))
示例#9
0
def chain_compose_before_tps_test():
    # Composing a Translation -> Scale chain *before* a TPS must give the same
    # result as applying the three transforms manually in that order.
    src_landmarks = PointCloud(np.random.random([10, 2]))
    dst_landmarks = PointCloud(np.random.random([10, 2]))
    tps = ThinPlateSplines(src_landmarks, dst_landmarks)

    translate = Translation([3, 4])
    scale = Scale([4, 2])
    composed = TransformChain([translate, scale]).compose_before(tps)

    probe = PointCloud(np.random.random([10, 2]))

    expected = tps.apply(scale.apply(translate.apply(probe)))
    actual = composed.apply(probe)
    assert np.all(expected.points == actual.points)
示例#10
0
def test_chain_compose_after_inplace_tps():
    # Composing a TPS *after* an existing Translation -> Scale chain, in place,
    # means the TPS runs first. Check this against manual application.
    lm_from = PointCloud(np.random.random([10, 2]))
    lm_to = PointCloud(np.random.random([10, 2]))
    warp = ThinPlateSplines(lm_from, lm_to)

    shift = Translation([3, 4])
    stretch = Scale([4, 2])
    chain = TransformChain([shift, stretch])
    chain.compose_after_inplace(warp)

    query = PointCloud(np.random.random([10, 2]))

    by_hand = stretch.apply(shift.apply(warp.apply(query)))
    via_chain = chain.apply(query)
    assert np.all(by_hand.points == via_chain.points)
示例#11
0
def test_align_2d_translation_from_vector():
    # from_vector must return a *new* alignment. The original stays an
    # identity alignment, while the copy reflects the new parameters.
    vec = np.array([1, 2])
    known = Translation(vec)
    src = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    moved = known.apply(src)
    identity_est = AlignmentTranslation(src, src)
    updated_est = identity_est.from_vector(vec)
    # the original estimate is untouched
    assert_allclose(identity_est.source.points, src.points)
    assert_allclose(identity_est.target.points, src.points)
    # the new estimate keeps the source and moves the target
    assert_allclose(updated_est.source.points, src.points)
    assert_allclose(updated_est.target.points, moved.points)
示例#12
0
def test_align_2d_translation_from_vector():
    # Building a new alignment with from_vector must not mutate the
    # alignment it was called on.
    params = np.array([1, 2])
    reference_shift = Translation(params)
    base = PointCloud(np.array([[0, 1], [1, 1], [-1, -5], [3, -5]]))
    expected_target = reference_shift.apply(base)
    original = AlignmentTranslation(base, base)
    derived = original.from_vector(params)
    # original alignment: still source -> source
    assert_allclose(original.source.points, base.points)
    assert_allclose(original.target.points, base.points)
    # derived alignment: same source, translated target
    assert_allclose(derived.source.points, base.points)
    assert_allclose(derived.target.points, expected_target.points)
示例#13
0
def test_init_from_pointcloud_return_transform():
    # A 10x10 grid shifted by (5, 5) should produce a 9x9 image and a
    # transform whose vector undoes the shift.
    shift = Translation([5, 5])
    grid = shift.apply(PointCloud.init_2d_grid((10, 10)))
    image, transform = Image.init_from_pointcloud(grid, return_transform=True)
    assert image.shape == (9, 9)
    assert_allclose(transform.as_vector(), -shift.as_vector())
示例#14
0
def test_translation():
    # Applying a Translation to raw 3D points must add the offset to every row.
    offset = np.array([1, 2, 3])
    pts = np.random.rand(10, 3)
    moved = Translation(offset).apply(pts)
    assert_allclose(pts + offset, moved)
示例#15
0
def test_init_from_pointcloud_return_transform():
    # Building an image from a translated 10x10 grid must report a 9x9 shape
    # and return the inverse translation.
    offset_tr = Translation([5, 5])
    cloud = offset_tr.apply(PointCloud.init_2d_grid((10, 10)))
    img, returned_tr = Image.init_from_pointcloud(cloud, return_transform=True)
    assert img.shape == (9, 9)
    assert_allclose(returned_tr.as_vector(), -offset_tr.as_vector())
示例#16
0
def rescale_images_to_reference_shape(images,
                                      group,
                                      reference_shape,
                                      tight_mask=True,
                                      sd=svs_shape,
                                      target_group=None,
                                      verbose=False):
    r"""Normalise ``images`` to ``reference_shape`` and build dense correspondences.

    The images are rescaled to the alignment reference and their shapes are
    aligned. Shape descriptors are written to a cache directory, and an
    external MATLAB ``build_flow`` step is run, unless ``result.mat``
    already exists there. Its optical-flow fields are then used to warp a
    dense reference shape onto every image.

    Parameters
    ----------
    images : list of images
        Landmarked images. ``images[0].path.parent`` is used as the location
        of the shape-descriptor cache.
    group : str
        Landmark group holding the sparse shape of each image.
    reference_shape : PointCloud
        Reference shape. If any image carries an ``'LMS'`` group, the first
        ``_n_align_points`` points of the reference are used as the alignment
        shape and the remainder as the sparse shape.
    tight_mask : bool, optional
        If True, the reference-frame mask is the polygon spanned by the
        translated sparse reference shape. Otherwise the mask is constrained
        with ``constrain_mask_to_landmarks``.
    sd : optional
        Forwarded to ``_build_shape_desc`` as ``_shape_desc``.
    target_group : str or None, optional
        Forwarded to ``_build_shape_desc``.
    verbose : bool, optional
        Currently unused.

    Returns
    -------
    tuple
        ``(ni, transforms, reference_frame, n_landmarks, _n_align_points,
        _removed_transform, normalized_images, _rf_align, reference_shape,
        [align_t])``.
    """
    _has_lms_align = False
    _n_align_points = None
    _is_mc = False
    group_align = group
    _db_path = images[0].path.parent
    reference_align_shape = reference_shape
    n_landmarks = reference_shape.n_points

    # Expose any 'LMS' group as 'align' and remember the number of alignment
    # points (taken from the first image that has them).
    for i in images:
        if 'LMS' in i.landmarks.keys():
            _has_lms_align = True
            i.landmarks['align'] = i.landmarks['LMS']
            if not _n_align_points:
                _n_align_points = i.landmarks['align'].lms.n_points

    if _has_lms_align:
        # Split the reference into alignment points (head) and sparse shape (tail).
        group_align = 'align'
        reference_align_shape = PointCloud(
            reference_shape.points[:_n_align_points])
        reference_shape = PointCloud(reference_shape.points[_n_align_points:])
    else:
        # No explicit alignment landmarks: obtain correspondences to the
        # reference with non-rigid ICP and store them as '_nicp'.
        group_align = '_nicp'
        for i in images:
            source_shape = TriMesh(reference_shape.points)
            _, points_corr = nicp(source_shape, i.landmarks[group].lms)
            i.landmarks[group_align] = PointCloud(
                i.landmarks[group].lms.points[points_corr])

    print('  - Normalising')
    normalized_images = [
        i.rescale_to_pointcloud(reference_align_shape, group=group_align)
        for i in images
    ]

    # Global Parameters
    alpha = 30
    pdm = 0
    lms_shapes = [i.landmarks[group_align].lms for i in normalized_images]
    shapes = [i.landmarks[group].lms for i in normalized_images]
    n_shapes = len(shapes)

    # Align Shapes Using ICP
    (aligned_shapes, target_shape, _removed_transform, _icp_transform,
     _icp) = align_shapes(shapes, reference_shape, lms_shapes=lms_shapes,
                          align_target=reference_align_shape)

    # Build Reference Frame from Aligned Shapes
    bound_list = []
    for s in [reference_shape] + aligned_shapes.tolist():
        bmin, bmax = s.bounds()
        bound_list.append(bmin)
        bound_list.append(bmax)
        bound_list.append(np.array([bmin[0], bmax[1]]))
        bound_list.append(np.array([bmax[0], bmin[1]]))
    bound_list = PointCloud(np.array(bound_list))

    # Square frame whose side is the largest extent over all shapes.
    scales = np.max(bound_list.points, axis=0) - np.min(bound_list.points,
                                                        axis=0)
    max_scale = np.max(scales)
    bound_list = PointCloud(
        np.array([[max_scale, max_scale], [max_scale, 0], [0, max_scale],
                  [0, 0]]))

    reference_frame = build_reference_frame(bound_list, boundary=15)

    # Translation between reference shape and aligned shapes
    align_centre = target_shape.centre_of_bounds()
    align_t = Translation(reference_frame.centre() - align_centre)

    _rf_align = Translation(align_centre - reference_frame.centre())

    # Set All True Pixels for Mask. ``np.bool`` was removed in NumPy 1.24;
    # the builtin ``bool`` is the equivalent dtype.
    reference_frame.mask.pixels = np.ones(reference_frame.mask.pixels.shape,
                                          dtype=bool)

    # Create Cache Directory
    home_dir = os.getcwd()
    dir_hex = uuid.uuid1()

    sd_path_in = '{}/shape_discriptor'.format(
        _db_path) if _db_path else '{}/.cache/{}/sd_training'.format(
            home_dir, dir_hex)
    sd_path_out = sd_path_in

    matE = MatlabExecuter()
    # NOTE(review): hard-coded, user-specific location of the MATLAB code.
    mat_code_path = '/vol/atlas/homes/yz4009/gitdev/mfsfdev'

    # Build shape descriptors. Whether the helper skips existing ones is
    # decided inside _build_shape_desc — TODO confirm.
    _build_shape_desc(sd_path_in,
                      normalized_images,
                      reference_shape,
                      aligned_shapes,
                      align_t,
                      reference_frame,
                      _icp_transform,
                      _is_mc=_is_mc,
                      group=group,
                      target_align_shape=reference_align_shape,
                      _shape_desc=sd,
                      align_group=group_align,
                      target_group=target_group)

    # Call Matlab to Build Flows (skipped when a cached result exists)
    if not isfile('{}/result.mat'.format(sd_path_in)):
        print('  - Building Shape Flow')
        matE.cd(mat_code_path)
        ext = 'gif'
        isLms = int(_has_lms_align)
        isBasis = 0
        fstr = ('gpuDevice(1);'
                'addpath(\'{0}/{1}\');'
                'addpath(\'{0}/{2}\');'
                'build_flow(\'{3}\', \'{4}\', \'{5}\', {6}, {7}, '
                '{8}, \'{3}/{9}\', {10}, {11}, {14}, {15}, {12}, \'{13}\')').format(
                    mat_code_path, 'cudafiles', 'tools',
                    sd_path_in, sd_path_out, 'sd_%04d.{}'.format(ext),
                    0,
                    1, n_shapes, 'bas.mat',
                    alpha, pdm, 30, 'sd_%04d_lms.pts', isBasis, isLms)
        sys.stderr.write(fstr)
        sys.stderr.write(fstr.replace('build_flow', 'build_flow_test'))
        p = matE.run_function(fstr)
        p.wait()

    # Retrieve Results
    mat = sio.loadmat('{}/result.mat'.format(sd_path_out))

    _u, _v = mat['u'], mat['v']

    # Build Transforms: one flow field per shape, stacked on the last axis.
    print("  - Build Transform")
    transforms = [OpticalFlowTransform(_u[:, :, i], _v[:, :, i])
                  for i in range(n_shapes)]

    # build dense shapes
    print("  - Build Dense Shapes")

    testing_points = reference_frame.mask.true_indices()
    ref_sparse_lms = align_t.apply(reference_shape)
    close_mask = BooleanImage(
        matpath(ref_sparse_lms.points).contains_points(testing_points).reshape(
            reference_frame.mask.mask.shape))

    if tight_mask:
        reference_frame.mask = close_mask
    else:
        reference_frame.landmarks['sparse'] = ref_sparse_lms
        reference_frame.constrain_mask_to_landmarks(group='sparse')

    # Get Dense Shape from Masked Image: alignment points, sparse points,
    # then every pixel inside the mask.
    dense_reference_shape = PointCloud(
        np.vstack((align_t.apply(reference_align_shape).points,
                   align_t.apply(reference_shape).points,
                   reference_frame.mask.true_indices())))

    # Set Dense Shape as Reference Landmarks
    reference_frame.landmarks['source'] = dense_reference_shape
    dense_shapes = [t.apply(dense_reference_shape) for t in transforms]

    ni = []
    for i, ds, t in zip(normalized_images, dense_shapes, _removed_transform):
        img = i.warp_to_shape(reference_frame.shape,
                              _rf_align.compose_before(t),
                              warp_landmarks=True)
        img.landmarks[group] = ds
        ni.append(img)

    return ni, transforms, reference_frame, n_landmarks, _n_align_points, _removed_transform, normalized_images, _rf_align, reference_shape, [
        align_t
    ]
示例#17
0
def non_rigid_icp_generator(source, target, eps=1e-3,
                            stiffness_weights=None, data_weights=None,
                            landmark_group=None, landmark_weights=None,
                            v_i_update_func=None, verbose=False):
    r"""
    Deforms the source trimesh to optimally align with the target.

    This is a generator. After every solve it yields
    ``(current_instance, info_dict)``:

    - ``current_instance`` is a copy of ``source`` holding the current
      deformed points.
    - ``info_dict`` holds per-iteration diagnostics: alpha, iteration,
      proportions of omitted vertices, delta, the masks, nearest points and
      the deformation of this step, plus beta / lm_err when landmarks are
      used.

    Parameters
    ----------
    source : TriMesh
        Template mesh to deform. Must have ``landmarks[landmark_group]``
        when ``landmark_group`` is given.
    target : TriMesh
        Mesh to fit.
    eps : float, optional
        Convergence threshold. Each stiffness level iterates until the RMS
        change of the solved affine parameters falls below ``eps``.
    stiffness_weights : list, optional
        One entry per stiffness level. A scalar applies globally; an ndarray
        gives a per-vertex stiffness of length ``source.n_points``. Defaults
        to ``[50, 20, 5, 2, 0.8, 0.5, 0.35, 0.2]``.
    data_weights : list, optional
        One entry per stiffness level. Multiplies the data-term weights.
    landmark_group : str, optional
        Landmark group present on both meshes. Enables an initial similarity
        alignment and a landmark term in every solve.
    landmark_weights : list, optional
        One entry per stiffness level. Defaults to
        ``[5, 2, .5, 0, 0, 0, 0, 0]`` when ``landmark_group`` is set.
    v_i_update_func : callable, optional
        Called with the current deformed template (a copy of ``source``)
        after each solve. The points of the mesh it returns replace the
        current estimate.
    verbose : bool, optional
        Print progress information.

    Raises
    ------
    ValueError
        If a per-vertex stiffness array does not have one entry per vertex.
    """
    # If landmarks are provided, start with a similarity alignment between
    # the landmark sets to initialise the deformation sensibly.
    if landmark_group is not None:
        if verbose:
            print("'{}' landmarks will be used as "
                  "a landmark constraint.".format(landmark_group))
            print("performing similarity alignment using landmarks")
        lm_align = AlignmentSimilarity(source.landmarks[landmark_group],
                                       target.landmarks[landmark_group]).as_non_alignment()
        source = lm_align.apply(source)

    # NOTE: scale normalisation of the meshes (centring + rescaling into a
    # unit-diagonal box) is currently disabled, so the NICP working space is
    # the input space. ``restore`` maps working space back to the caller's
    # space and ``prepare`` is its inverse. Both are identities for now; they
    # are kept so a normalisation can be reinstated in one place.
    restore = Translation([0, 0, 0])
    prepare = restore.pseudoinverse()

    n_dims = source.n_dims
    # Homogeneous dimension (1 extra for translation effects)
    h_dims = n_dims + 1
    points, trilist = source.points, source.trilist
    n = points.shape[0]  # record number of points

    # Boundary triangles of the *target*. The closest-point locator below
    # returns target triangle indices, so the boundary must come from the
    # target mesh.
    edge_tris = target.boundary_tri_index()

    M_s, unique_edge_pairs = node_arc_incidence_matrix(source)

    # weight matrix
    G = np.identity(n_dims + 1)

    M_kron_G_s = sp.kron(M_s, G)

    # build octree for finding closest points on target.
    target_vtk = trimesh_to_vtk(target)
    closest_points_on_target = VTKClosestPointLocator(target_vtk)

    # save out the target normals. We need them for the weight matrix.
    target_tri_normals = target.tri_normals()

    # init transformation: one (h_dims x n_dims) affine block per vertex,
    # stacked to shape (n * h_dims, n_dims).
    X_prev = np.tile(np.zeros((n_dims, h_dims)), n).T
    v_i = points

    if stiffness_weights is not None:
        if verbose:
            print('using user-defined stiffness_weights')
        validate_weights('stiffness_weights', stiffness_weights,
                         source.n_points, verbose=verbose)
    else:
        # these values have been empirically found to perform well for well
        # rigidly aligned facial meshes
        stiffness_weights = [50, 20, 5, 2, 0.8, 0.5, 0.35, 0.2]
        if verbose:
            print('using default '
                  'stiffness_weights: {}'.format(stiffness_weights))

    n_iterations = len(stiffness_weights)

    if landmark_weights is not None:
        if verbose:
            print('using user defined '
                  'landmark_weights: {}'.format(landmark_weights))
    elif landmark_group is not None:
        # these values have been empirically found to perform well for well
        # rigidly aligned facial meshes
        landmark_weights = [5, 2, .5, 0, 0, 0, 0, 0]
        if verbose:
            print('using default '
                  'landmark_weights: {}'.format(landmark_weights))
    else:
        # no landmark_weights provided - no landmark_group in use. We still
        # need a landmark group for the iterator
        landmark_weights = [None] * n_iterations

    # We should definitely have some landmark weights set now - check the
    # number is correct.
    # Note we say verbose=False, as we have done custom reporting above, and
    # per-vertex landmarks are not supported.
    validate_weights('landmark_weights', landmark_weights, source.n_points,
                     n_iterations=n_iterations, verbose=False)

    if data_weights is not None:
        if verbose:
            print('using user-defined data_weights')
        validate_weights('data_weights', data_weights,
                         source.n_points, n_iterations=n_iterations,
                         verbose=verbose)
    else:
        data_weights = [None] * n_iterations
        if verbose:
            print('Not customising data_weights')

    # we need to prepare some indices for efficient construction of the D
    # sparse matrix.
    row = np.hstack((np.repeat(np.arange(n)[:, None], n_dims, axis=1).ravel(),
                     np.arange(n)))

    x = np.arange(n * h_dims).reshape((n, h_dims))
    col = np.hstack((x[:, :n_dims].ravel(),
                     x[:, n_dims]))
    o = np.ones(n)

    if landmark_group is not None:
        source_lm_index = source.distance_to(
            source.landmarks[landmark_group]).argmin(axis=0)
        target_lms = target.landmarks[landmark_group]
        U_L = target_lms.points
        n_landmarks = target_lms.n_points
        lm_mask = np.in1d(row, source_lm_index)
        col_lm = col[lm_mask]
        # pull out the rows for the lms - but the values are
        # all wrong! need to map them back to the order of the landmarks
        row_lm_to_fix = row[lm_mask]
        source_lm_index_l = list(source_lm_index)
        row_lm = np.array([source_lm_index_l.index(r) for r in row_lm_to_fix])

    for i, (alpha, beta, gamma) in enumerate(zip(stiffness_weights,
                                                 landmark_weights,
                                                 data_weights), 1):
        alpha_is_per_vertex = isinstance(alpha, np.ndarray)
        if alpha_is_per_vertex:
            # stiffness is provided per-vertex
            if alpha.shape[0] != source.n_points:
                raise ValueError(
                    'per-vertex stiffness must have one entry per source '
                    'vertex ({}), got {}'.format(source.n_points,
                                                 alpha.shape[0]))
            alpha_per_edge = alpha[unique_edge_pairs].mean(axis=1)
            alpha_M_s = sp.diags(alpha_per_edge).dot(M_s)
            alpha_M_kron_G_s = sp.kron(alpha_M_s, G)
        else:
            # stiffness is global - just a scalar multiply. Note that here
            # we don't have to recalculate M_kron_G_s
            alpha_M_kron_G_s = alpha * M_kron_G_s

        if verbose:
            a_str = (alpha if not alpha_is_per_vertex
                     else 'min: {:.2f}, max: {:.2f}'.format(alpha.min(),
                                                            alpha.max()))
            i_str = '{}/{}: stiffness: {}'.format(i, len(stiffness_weights), a_str)
            if landmark_group is not None:
                i_str += '  lm_weight: {}'.format(beta)
            print(i_str)

        j = 0

        while True:  # iterate until convergence
            j += 1  # track the iterations for this alpha/landmark weight

            # find nearest neighbour and the normals
            U, tri_indices = closest_points_on_target(v_i)

            # ---- WEIGHTS ----
            # 1.  Edges
            # Are any of the corresponding tris on the edge of the target?
            # Where they are we return a false weight (we *don't* want to
            # include these points in the solve)
            w_i_e = np.in1d(tri_indices, edge_tris, invert=True)

            # 2. Normals
            # Calculate the normals of the current v_i
            v_i_tm = TriMesh(v_i, trilist=trilist)
            v_i_n = v_i_tm.vertex_normals()
            # Extract the corresponding normals from the target
            u_i_n = target_tri_normals[tri_indices]
            # If the dot of the normals is lt 0.9 don't contrib to deformation
            w_i_n = (u_i_n * v_i_n).sum(axis=1) > 0.9

            # 3. Self-intersection weighting is not used: it added roughly
            # 12% to the running cost without clearly helping the fit.

            # Form the overall w_i from the normals and edge masks.
            # ``np.float`` was removed in NumPy 1.24; ``float`` is equivalent.
            w_i = np.logical_and(w_i_n, w_i_e).astype(float)

            # proportion of vertices excluded by each mask
            prop_w_i = (n - w_i.sum() * 1.0) / n
            prop_w_i_n = (n - w_i_n.sum() * 1.0) / n
            prop_w_i_e = (n - w_i_e.sum() * 1.0) / n

            if gamma is not None:
                w_i = w_i * gamma

            # Build the sparse diagonal weight matrix
            W_s = sp.diags(w_i.astype(float)[None, :], [0])

            data = np.hstack((v_i.ravel(), o))
            D_s = sp.coo_matrix((data, (row, col)))

            to_stack_A = [alpha_M_kron_G_s, W_s.dot(D_s)]
            to_stack_B = [np.zeros((alpha_M_kron_G_s.shape[0], n_dims)),
                          U * w_i[:, None]]  # nullify nearest points by w_i

            if landmark_group is not None:
                D_L = sp.coo_matrix((data[lm_mask], (row_lm, col_lm)),
                                    shape=(n_landmarks, D_s.shape[1]))
                to_stack_A.append(beta * D_L)
                to_stack_B.append(beta * U_L)

            A_s = sp.vstack(to_stack_A).tocsr()
            B_s = sp.vstack(to_stack_B).tocsr()

            X = spsolve(A_s, B_s)

            # deform template
            v_i_prev = v_i
            v_i = D_s.dot(X)
            delta_v_i = v_i - v_i_prev

            if v_i_update_func:
                # custom logic is provided to update the current template
                # deformation. This is typically used by Active NICP.

                # take the v_i points matrix and convert back to a TriMesh in
                # the original space
                def_template = restore.apply(source.from_vector(v_i.ravel()))

                # perform the update
                updated_def_template = v_i_update_func(def_template)

                # convert back to points in the NICP space
                v_i = prepare.apply(updated_def_template.points)

            # RMS change of the affine parameters since the previous solve
            err = np.linalg.norm(X_prev - X, ord='fro')
            stop_criterion = err / np.sqrt(np.size(X_prev))

            if landmark_group is not None:
                src_lms = v_i[source_lm_index]
                # mean Euclidean distance between corresponding landmarks
                # (sum the squared components *before* the square root)
                lm_err = np.sqrt(((src_lms - U_L) ** 2).sum(axis=1)).mean()

            if verbose:
                v_str = (' - {} stop crit: {:.3f}  '
                         'total: {:.0%}  norms: {:.0%}  '
                         'edges: {:.0%}'.format(j, stop_criterion,
                                                prop_w_i, prop_w_i_n,
                                                prop_w_i_e))
                if landmark_group is not None:
                    v_str += '  lm_err: {:.4f}'.format(lm_err)

                print(v_str)

            X_prev = X

            # track the progress of the algorithm per-iteration
            info_dict = {
                'alpha': alpha,
                'iteration': j,
                'prop_omitted': prop_w_i,
                'prop_omitted_norms': prop_w_i_n,
                'prop_omitted_edges': prop_w_i_e,
                'delta': err,
                'mask_normals': w_i_n,
                'mask_edges': w_i_e,
                'mask_all': w_i,
                'nearest_points': restore.apply(U),
                'deformation_per_step': delta_v_i
            }

            current_instance = source.copy()
            current_instance.points = v_i.copy()

            if landmark_group is not None:
                info_dict['beta'] = beta
                info_dict['lm_err'] = lm_err
                current_instance.landmarks[landmark_group] = PointCloud(src_lms)

            yield restore.apply(current_instance), info_dict

            if stop_criterion < eps:
                break
def augment_face_image(img, image_size=256, crop_size=248, angle_range=30, flip=True, warp_mode='constant'):
    """Basic image augmentation: random crop, horizontal flip and rotation.

    Parameters
    ----------
    img : Image
        Image to augment. Crop offsets are drawn relative to ``image_size``,
        so ``img`` is presumably ``image_size`` x ``image_size`` — TODO
        confirm at callers.
    image_size : int, optional
        Side of the square output; the result is always resized to
        ``(image_size, image_size)``.
    crop_size : int, optional
        Side of the random square crop. Must not exceed ``image_size``.
    angle_range : float, optional
        The rotation angle is drawn uniformly from
        ``[-angle_range, angle_range)`` degrees.
    flip : bool, optional
        If True, mirror the image (and any 68-point landmark groups) with
        probability 0.5.
    warp_mode : str, optional
        Border mode forwarded to ``warp_to_shape`` as ``mode``.

    Returns
    -------
    Image
        The augmented image, resized to ``(image_size, image_size)``.

    Raises
    ------
    ValueError
        If ``crop_size`` is larger than ``image_size``.
    """

    # from menpo
    def round_image_shape(shape, round_mode):
        if round_mode not in ['ceil', 'round', 'floor']:
            raise ValueError('round must be either ceil, round or floor')
        # Ensure that the '+' operator means concatenate tuples.
        # ``np.int`` was removed in NumPy 1.24; ``int`` is the equivalent dtype.
        return tuple(getattr(np, round_mode)(shape).astype(int))

    # taken from MDM
    def mirror_landmarks_68(lms, im_size):
        return PointCloud(abs(np.array([0, im_size[1]]) - lms.as_vector(
        ).reshape(-1, 2))[mirrored_parts_68])

    # taken from MDM
    def mirror_image(im):
        im = im.copy()
        im.pixels = im.pixels[..., ::-1].copy()

        for group in im.landmarks:
            lms = im.landmarks[group]
            if lms.points.shape[0] == 68:
                im.landmarks[group] = mirror_landmarks_68(lms, im.shape)

        return im

    flip_rand = np.random.random() > 0.5
    rot_rand = True  # like ECT
    crop_rand = True  # like ECT

    if crop_rand:
        lim = image_size - crop_size
        if lim < 0:
            raise ValueError('crop_size ({}) must not exceed image_size '
                             '({})'.format(crop_size, image_size))
        # randint's upper bound is exclusive: use lim + 1 so every valid
        # offset 0..lim can be drawn, including when crop_size == image_size.
        min_crop_inds = np.random.randint(0, lim + 1, 2)
        max_crop_inds = min_crop_inds + crop_size
        img = img.crop(min_crop_inds, max_crop_inds)

    if flip and flip_rand:
        img = mirror_image(img)

    if rot_rand:
        rot_angle = 2 * angle_range * np.random.random_sample() - angle_range

        # Get image's bounding box coordinates
        bbox = bounding_box((0, 0), [img.shape[0] - 1, img.shape[1] - 1])
        # Translate to origin and rotate counter-clockwise
        trans = Translation(-img.centre(),
                            skip_checks=True).compose_before(
            Rotation.init_from_2d_ccw_angle(rot_angle, degrees=True))
        rotated_bbox = trans.apply(bbox)
        # Create new translation so that min bbox values go to 0
        t = Translation(-rotated_bbox.bounds()[0])
        trans.compose_before_inplace(t)
        rotated_bbox = trans.apply(bbox)
        # Output image's shape is the range of the rotated bounding box,
        # rounded to the nearest integer.
        shape = round_image_shape(rotated_bbox.range() + 1, 'round')

        img = img.warp_to_shape(
            shape, trans.pseudoinverse(), warp_landmarks=True, mode=warp_mode)

    img = img.resize([image_size, image_size])

    return img