skeletor.post

The skeletor.post module contains functions to post-process skeletons after skeletonization.

Fixing issues with skeletons

Depending on your mesh, pre-processing and the parameters you chose for skeletonization, chances are that your skeleton will not come out perfectly.

skeletor.post.clean_up can help you solve some potential issues:

  • skeleton nodes (vertices) that sit outside or right on the surface instead of centered inside the mesh
  • superfluous "hairs" on otherwise straight bits

skeletor.post.smooth will smooth out the skeleton.

skeletor.post.despike can help you remove spikes in the skeleton where single nodes are out of alignment.

skeletor.post.remove_bristles will remove bristles from the skeleton.

Computing radius information

Only skeletor.skeletonize.by_wavefront() provides radii off the bat. For all other methods, skeletor.post.radii can help you (re-)generate radius information for the skeletons.

#    This script is part of skeletor (http://www.github.com/navis-org/skeletor).
#    Copyright (C) 2018 Philipp Schlegel
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.

r"""
The `skeletor.post` module contains functions to post-process skeletons after
skeletonization.

### Fixing issues with skeletons

Depending on your mesh, pre-processing and the parameters you chose for
skeletonization, chances are that your skeleton will not come out perfectly.

`skeletor.post.clean_up` can help you solve some potential issues:

- skeleton nodes (vertices) that sit outside or right on the surface instead
  of centered inside the mesh
- superfluous "hairs" on otherwise straight bits

`skeletor.post.smooth` will smooth out the skeleton.

`skeletor.post.despike` can help you remove spikes in the skeleton where
single nodes are out of alignment.

`skeletor.post.remove_bristles` will remove bristles from the skeleton.

### Computing radius information

Only `skeletor.skeletonize.by_wavefront()` provides radii off the bat. For all
other methods, `skeletor.post.radii` can help you (re-)generate radius
information for the skeletons.

"""

from .radiusextraction import radii
from .postprocessing import clean_up, smooth, despike, remove_bristles, recenter_vertices, fix_outside_edges

__docformat__ = "numpy"
__all__ = ["radii", "clean_up", "smooth", "despike", "remove_bristles", "recenter_vertices", "fix_outside_edges"]
def radii(s, mesh=None, method='knn', aggregate='mean', validate=False, **kwargs):
def radii(s, mesh=None, method='knn', aggregate='mean', validate=False, **kwargs):
    """Extract radii for given skeleton table.

    Important
    ---------
    This function really only produces useful radii if the skeleton is centered
    inside the mesh. `by_wavefront` does that by default whereas all other
    skeletonization methods don't. Your best bet to get centered skeletons is
    to contract the mesh first (`sk.pre.contract`).

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to add radii to.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    method :    "knn" | "ray"
                Whether and how to add radius information to each node::

                    - "knn" uses k-nearest-neighbors to get radii: fast but
                      potential for being very wrong
                    - "ray" uses ray-casting to get radii: slower but sometimes
                      less wrong

    aggregate : "mean" | "median" | "max" | "min" | "percentile75"
                Function used to aggregate radii over the sample (i.e. across
                k nearest-neighbors or ray intersections).
    validate :  bool
                If True, will try to fix potential issues with the mesh
                (e.g. infinite values, duplicate vertices, degenerate faces)
                before extracting radii. Note that this might make changes to
                your mesh inplace!
    **kwargs
                Keyword arguments are passed to the respective method:

                For method "knn"::

                    n :             int (default 5)
                                    Radius will be the mean over n nearest-neighbors.

                For method "ray"::

                    n_rays :        int (default 20)
                                    Number of rays to cast for each node.
                    projection :    "sphere" (default) | "tangents"
                                    Whether to cast rays in a sphere around each node or in a
                                    circle orthogonally to the node's tangent vector.
                    fallback :      "knn" (default) | None | number
                                    If a point is outside or right on the surface of the mesh,
                                    the ray casting will return nonsense results. We can either
                                    ignore those cases (``None``), assign an arbitrary number or
                                    fall back to radii from k-nearest-neighbors (``"knn"``).

    Returns
    -------
    None
                    But attaches `radius` to the skeleton's SWC table. Existing
                    values are replaced!

    """
    if mesh is None:
        mesh = s.mesh

    mesh = make_trimesh(mesh, validate=validate)

    if method == 'knn':
        radius = get_radius_knn(s.swc[['x', 'y', 'z']].values,
                                aggregate=aggregate,
                                mesh=mesh, **kwargs)
    elif method == 'ray':
        radius = get_radius_ray(s.swc,
                                mesh=mesh,
                                aggregate=aggregate,
                                **kwargs)
    else:
        raise ValueError(f'Unknown method "{method}"')

    s.swc['radius'] = radius

    return

Extract radii for given skeleton table.

Important

This function really only produces useful radii if the skeleton is centered inside the mesh. by_wavefront does that by default whereas all other skeletonization methods don't. Your best bet to get centered skeletons is to contract the mesh first (sk.pre.contract).

Parameters
  • s (skeletor.Skeleton): Skeleton to add radii to.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • method ("knn" | "ray"): Whether and how to add radius information to each node::

    - "knn" uses k-nearest-neighbors to get radii: fast but
      potential for being very wrong
    - "ray" uses ray-casting to get radii: slower but sometimes
      less wrong
    
  • aggregate ("mean" | "median" | "max" | "min" | "percentile75"): Function used to aggregate radii over the sample (i.e. across k nearest-neighbors or ray intersections).
  • validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before extracting radii. Note that this might make changes to your mesh inplace!
  • **kwargs: Keyword arguments are passed to the respective method:

For method "knn"::

n :             int (default 5)
                Radius will be the mean over n nearest-neighbors.

For method "ray"::

n_rays :        int (default 20)
                Number of rays to cast for each node.
projection :    "sphere" (default) | "tangents"
                Whether to cast rays in a sphere around each node or in a
                circle orthogonally to the node's tangent vector.
fallback :      "knn" (default) | None | number
                If a point is outside or right on the surface of the mesh,
                the ray casting will return nonsense results. We can either
                ignore those cases (``None``), assign an arbitrary number or
                fall back to radii from k-nearest-neighbors (``"knn"``).
Returns
  • None: But attaches radius to the skeleton's SWC table. Existing values are replaced!
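The excerpt doesn't show `get_radius_knn` itself, so here is a brute-force numpy sketch of the "knn" idea (`knn_radius` is a hypothetical helper for illustration, not the library's implementation, which uses a KD-tree):

```python
import numpy as np

def knn_radius(nodes, verts, n=5, aggregate=np.mean):
    """Estimate a radius per node as the aggregate distance to the n closest mesh vertices."""
    # Pairwise distances between skeleton nodes and mesh vertices
    d = np.linalg.norm(nodes[:, None, :] - verts[None, :, :], axis=2)
    # Keep the n smallest distances per node and aggregate them
    nearest = np.sort(d, axis=1)[:, :n]
    return aggregate(nearest, axis=1)

# A single node at the center of a ring of mesh vertices with radius 2
angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
verts = np.column_stack([2 * np.cos(angles), 2 * np.sin(angles), np.zeros(8)])
nodes = np.array([[0.0, 0.0, 0.0]])

print(knn_radius(nodes, verts, n=4))  # ≈ [2.]
```

This also makes the caveat above concrete: if the node were *not* centered in the mesh, the n closest vertices would all come from the near side and the estimated radius would be systematically too small.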
def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
    """Clean up the skeleton.

    This function bundles a bunch of procedures to clean up the skeleton:

      1. Remove twigs that are running parallel to their parent branch
      2. Move nodes outside the mesh back inside (or at least snap to surface)

    Note that this is not a magic bullet and some of this will not work (well)
    if the original mesh was degenerate (e.g. internal faces or not watertight)
    to begin with.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    validate :  bool
                If True, will try to fix potential issues with the mesh
                (e.g. infinite values, duplicate vertices, degenerate faces)
                before cleaning up. Note that this might change your mesh
                inplace!
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify `s` inplace.

    **kwargs
                Keyword arguments are passed to the bundled function.

                For `skeletor.postprocessing.drop_parallel_twigs`::

                 theta :     float (default 0.01)
                             For each twig we compute the dot product between the
                             tangent vectors of it and its parent. If these line up
                             perfectly the dot product will equal 1. ``theta``
                             determines how much that value can differ from 1 for
                             us to still prune the twig: higher theta = more
                             pruning.

    Returns
    -------
    s_clean :   skeletor.Skeleton
                Hopefully improved skeleton.

    """
    if mesh is None:
        mesh = s.mesh

    mesh = make_trimesh(mesh, validate=validate)

    if not inplace:
        s = s.copy()

    # Drop parallel twigs
    _ = drop_parallel_twigs(s, theta=kwargs.get("theta", 0.01), inplace=True)

    # Recenter vertices
    _ = recenter_vertices(s, mesh, inplace=True)

    return s

Clean up the skeleton.

This function bundles a bunch of procedures to clean up the skeleton:

  1. Remove twigs that are running parallel to their parent branch
  2. Move nodes outside the mesh back inside (or at least snap to surface)

Note that this is not a magic bullet and some of this will not work (well) if the original mesh was degenerate (e.g. internal faces or not watertight) to begin with.

Parameters
  • s (skeletor.Skeleton): Skeleton to clean up.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before cleaning up. Note that this might change your mesh inplace!
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify the s inplace.
  • **kwargs: Keyword arguments are passed to the bundled function.

For skeletor.postprocessing.drop_parallel_twigs::

theta :     float (default 0.01)
            For each twig we compute the dot product between the tangent
            vectors of it and its parent. If these line up perfectly the
            dot product will equal 1. theta determines how much that value
            can differ from 1 for us to still prune the twig: higher
            theta = more pruning.

Returns
  • s_clean (skeletor.Skeleton): Hopefully improved skeleton.
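The theta criterion can be sketched in a few lines of numpy (`is_parallel_twig` is a hypothetical helper for illustration, not the actual `drop_parallel_twigs` implementation):

```python
import numpy as np

def is_parallel_twig(twig_vec, parent_vec, theta=0.01):
    """Flag a twig whose tangent (nearly) lines up with its parent's."""
    # Normalize both tangent vectors to unit length
    t = twig_vec / np.linalg.norm(twig_vec)
    p = parent_vec / np.linalg.norm(parent_vec)
    # A dot product within `theta` of 1 means the twig runs (almost)
    # parallel to its parent branch and is a candidate for pruning
    return np.dot(t, p) >= 1 - theta

print(is_parallel_twig(np.array([1.0, 0.001, 0.0]), np.array([1.0, 0.0, 0.0])))  # True
print(is_parallel_twig(np.array([0.0, 1.0, 0.0]), np.array([1.0, 0.0, 0.0])))    # False
```

Raising theta widens the cone of directions counted as "parallel", i.e. more aggressive pruning.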
def smooth(s, window: int = 3, to_smooth: list = ['x', 'y', 'z'], inplace: bool = False):

def smooth(
    s, window: int = 3, to_smooth: list = ["x", "y", "z"], inplace: bool = False
):
    """Smooth skeleton using rolling windows.

    Parameters
    ----------
    s :             skeletor.Skeleton
                    Skeleton to be processed.
    window :        int, optional
                    Size (N observations) of the rolling window in number of
                    nodes.
    to_smooth :     list
                    Columns of the node table to smooth. Should work with any
                    numeric column (e.g. 'radius').
    inplace :       bool
                    If False will make and return a copy of the skeleton. If
                    True, will modify `s` inplace.

    Returns
    -------
    s :             skeletor.Skeleton
                    Skeleton with smoothed node table.

    """
    if not inplace:
        s = s.copy()

    # Index the node table by node ID
    nodes = s.swc.set_index("node_id", inplace=False).copy()

    to_smooth = np.array(to_smooth)
    miss = to_smooth[~np.isin(to_smooth, nodes.columns)]
    if len(miss):
        raise ValueError(f"Column(s) not found in node table: {miss}")

    # Go over each segment and smooth
    for seg in s.get_segments():
        # Get this segment's values for the columns to smooth
        this_co = nodes.loc[seg, to_smooth]

        # Rolling mean over `window` nodes; min_periods=1 keeps segment ends
        interp = this_co.rolling(window, min_periods=1).mean()

        nodes.loc[seg, to_smooth] = interp.values

    # Reassign nodes
    s.swc = nodes.reset_index(drop=False, inplace=False)

    return s

Smooth skeleton using rolling windows.

Parameters
  • s (skeletor.Skeleton): Skeleton to be processed.
  • window (int, optional): Size (N observations) of the rolling window in number of nodes.
  • to_smooth (list): Columns of the node table to smooth. Should work with any numeric column (e.g. 'radius').
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify the s inplace.
Returns
  • s (skeletor.Skeleton): Skeleton with smoothed node table.
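Internally, `smooth` applies a pandas rolling mean per segment (see the source above). The effect on a single zig-zag segment:

```python
import pandas as pd

# A zig-zag segment: x oscillates between 0 and 10
seg = pd.DataFrame({"x": [0.0, 10.0, 0.0, 10.0, 0.0]})

# The same call the source uses: a rolling window of 3 nodes;
# min_periods=1 keeps the segment ends (they are just less smoothed)
smoothed = seg.rolling(3, min_periods=1).mean()

print(smoothed["x"].tolist())  # [0.0, 5.0, 3.33..., 6.66..., 3.33...]
```

Larger `window` values average over more nodes and therefore smooth more aggressively, at the cost of rounding off genuine turns in the skeleton.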
def despike(s, sigma=5, max_spike_length=1, inplace=False, reverse=False):
def despike(s, sigma=5, max_spike_length=1, inplace=False, reverse=False):
    r"""Remove spikes in skeleton.

    For each node A, the euclidean distance to its next successor (parent)
    B and that node's successor C (i.e. A->B->C) is computed. If
    :math:`\frac{dist(A,B)}{dist(A,C)} > \sigma`, node B is considered a spike
    and realigned between A and C.

    Parameters
    ----------
    s :                 skeletor.Skeleton
                        Skeleton to be processed.
    sigma :             float | int, optional
                        Threshold for spike detection. Smaller sigma = more
                        aggressive spike detection.
    max_spike_length :  int, optional
                        Determines how long (# of nodes) a spike can be.
    inplace :           bool, optional
                        If False, a copy of the neuron is returned.
    reverse :           bool, optional
                        If True, will **also** walk the segments from proximal
                        to distal. Use this to catch spikes on e.g. terminal
                        nodes.

    Returns
    -------
    s :                 skeletor.Skeleton
                        Despiked neuron.

    """
    if not inplace:
        s = s.copy()

    # Index nodes table by node ID
    this_nodes = s.swc.set_index("node_id", inplace=False)

    segments = s.get_segments()
    segs_to_walk = list(segments)

    if reverse:
        # Also walk each segment in reverse (proximal to distal)
        segs_to_walk += [seg[::-1] for seg in segments]

    # For each spike length -> do this in reverse to correct the long
    # spikes first
    for l in list(range(1, max_spike_length + 1))[::-1]:
        # Go over all segments
        for seg in segs_to_walk:
            # Get nodes A, B and C of this segment
            this_A = this_nodes.loc[seg[: -l - 1]]
            this_B = this_nodes.loc[seg[l:-1]]
            this_C = this_nodes.loc[seg[l + 1 :]]

            # Get coordinates
            A = this_A[["x", "y", "z"]].values
            B = this_B[["x", "y", "z"]].values
            C = this_C[["x", "y", "z"]].values

            # Calculate euclidean distances A->B and A->C
            dist_AB = np.linalg.norm(A - B, axis=1)
            dist_AC = np.linalg.norm(A - C, axis=1)

            # Get the spikes
            spikes_ix = np.where(
                np.divide(dist_AB, dist_AC, where=dist_AC != 0) > sigma
            )[0]
            spikes = this_B.iloc[spikes_ix]

            if not spikes.empty:
                # Interpolate new position(s) between A and C
                new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2

                this_nodes.loc[spikes.index, ["x", "y", "z"]] = new_positions

    # Reassign node table
    s.swc = this_nodes.reset_index(drop=False, inplace=False)

    return s

Remove spikes in skeleton.

For each node A, the euclidean distance to its next successor (parent) B and that node's successor C (i.e. A->B->C) is computed. If \( \frac{dist(A,B)}{dist(A,C)} > \sigma \), node B is considered a spike and realigned between A and C.

Parameters
  • s (skeletor.Skeleton): Skeleton to be processed.
  • sigma (float | int, optional): Threshold for spike detection. Smaller sigma = more aggressive spike detection.
  • max_spike_length (int, optional): Determines how long (# of nodes) a spike can be.
  • inplace (bool, optional): If False, a copy of the neuron is returned.
  • reverse (bool, optional): If True, will also walk the segments from proximal to distal. Use this to catch spikes on e.g. terminal nodes.
Returns
  • s (skeletor.Skeleton): Despiked neuron.
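The spike criterion boils down to a ratio of two euclidean distances; a minimal numpy illustration for a single A->B->C triple:

```python
import numpy as np

A = np.array([0.0, 0.0, 0.0])   # node
C = np.array([2.0, 0.0, 0.0])   # B's successor
B = np.array([1.0, 10.0, 0.0])  # B sits way off the A->C line: a spike

sigma = 5
dist_AB = np.linalg.norm(A - B)  # ~10.05
dist_AC = np.linalg.norm(A - C)  # 2.0

if dist_AB / dist_AC > sigma:
    # B is a spike: realign it halfway between A and C
    B = A + (C - A) / 2

print(B)  # [1. 0. 0.]
```

Lowering sigma shrinks the detour a node is allowed to make before it counts as a spike.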
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
    """Remove "bristles" that sometimes occur along the backbone.

    Works by finding terminal twigs that consist of only a single node.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    los_only :  bool
                If True, will only remove bristles that are in line of sight of
                their parent. If False, will remove all single-node bristles.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify `s` inplace.

    Returns
    -------
    s :         skeletor.Skeleton
                Skeleton with single-node twigs removed.

    """
    if mesh is None:
        mesh = s.mesh

    # Make a copy of the skeleton
    if not inplace:
        s = s.copy()

    # Find branch points
    pcount = s.swc[s.swc.parent_id >= 0].groupby("parent_id").size()
    bp = pcount[pcount > 1].index

    # Find terminal twigs
    twigs = s.swc[~s.swc.node_id.isin(s.swc.parent_id)]
    twigs = twigs[twigs.parent_id.isin(bp)]

    if twigs.empty:
        return s

    if los_only:
        # Initialize ncollpyde Volume
        coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

        # Remove twigs that aren't inside the volume
        twigs = twigs[coll.contains(twigs[["x", "y", "z"]].values)]

        # Generate rays between all pairs and their parents
        sources = twigs[["x", "y", "z"]].values
        targets = (
            s.swc.set_index("node_id").loc[twigs.parent_id, ["x", "y", "z"]].values
        )

        # Get intersections: `ix` points to index of line segment; `loc` is the
        # x/y/z coordinate of the intersection and `is_backface` is True if
        # intersection happened at the inside of a mesh
        ix, loc, is_backface = coll.intersections(sources, targets)

        # Find pairs of twigs with no intersection - i.e. with line of sight
        los = ~np.isin(np.arange(sources.shape[0]), ix)

        # To remove: those that have line of sight to their parent
        to_remove = twigs[los]
    else:
        to_remove = twigs

    s.swc = s.swc[~s.swc.node_id.isin(to_remove.node_id)].copy()

    # Update the mesh map
    mesh_map = getattr(s, "mesh_map", None)
    if mesh_map is not None:
        for t in to_remove.itertuples():
            mesh_map[mesh_map == t.node_id] = t.parent_id

    # Reindex nodes
    s.reindex(inplace=True)

    return s

Remove "bristles" that sometimes occur along the backbone.

Works by finding terminal twigs that consist of only a single node.

Parameters
  • s (skeletor.Skeleton): Skeleton to clean up.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • los_only (bool): If True, will only remove bristles that are in line of sight of their parent. If False, will remove all single-node bristles.
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify the s inplace.
Returns
  • s (skeletor.Skeleton): Skeleton with single-node twigs removed.
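The twig detection from the source above can be reproduced on a toy SWC table:

```python
import pandas as pd

# Tiny skeleton: 0 -> 1, and node 1 has two children (2, 3), both leaves
swc = pd.DataFrame({
    "node_id":   [0, 1, 2, 3],
    "parent_id": [-1, 0, 1, 1],
})

# Branch points: parents referenced by more than one child
pcount = swc[swc.parent_id >= 0].groupby("parent_id").size()
bp = pcount[pcount > 1].index

# Terminal twigs: leaves (never referenced as a parent) hanging
# directly off a branch point
twigs = swc[~swc.node_id.isin(swc.parent_id)]
twigs = twigs[twigs.parent_id.isin(bp)]

print(twigs.node_id.tolist())  # [2, 3]
```

With `los_only=True`, these candidates would additionally be checked for an unobstructed line of sight to their parent before removal.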
def recenter_vertices(s, mesh=None, inplace=False):
def recenter_vertices(s, mesh=None, inplace=False):
    """Move nodes that ended up outside the mesh back inside.

    Nodes can end up outside the original mesh e.g. if the mesh contraction
    didn't do a good job (most likely because of internal/degenerate faces that
    messed up the normals). This function rectifies this by snapping those
    nodes back to the closest vertex and then tries to move them into the
    mesh's center. That second step is not guaranteed to work but at least you
    won't have any more nodes outside the mesh.

    Please note that if connected (!) nodes end up on the same position (i.e.
    because they snapped to the same vertex), we will collapse them.

    Parameters
    ----------
    s :         skeletor.Skeleton
    mesh :      trimesh.Trimesh
                Original mesh.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify `s` inplace.

    Returns
    -------
    s :         skeletor.Skeleton
                Skeleton with vertices recentered.

    """
    if mesh is None:
        mesh = s.mesh

    # Copy skeleton
    if not inplace:
        s = s.copy()

    # Find nodes that are outside the mesh
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)
    outside = ~coll.contains(s.vertices)

    # Skip if all inside
    if not any(outside):
        return s

    # For each outside node find the closest mesh vertex
    tree = scipy.spatial.cKDTree(mesh.vertices)
    dist, ix = tree.query(s.vertices[outside])

    # We don't want to just snap them back to the closest vertex but try to find
    # the center. For this we will:
    # 1. Move each vertex inside the mesh by just a bit
    # 2. Cast a ray along the vertices' normals and find the opposite sides of the mesh
    # 3. Calculate the distance

    # Get the closest vertex...
    closest_vertex = mesh.vertices[ix]
    # .. and offset the vertex positions by just a bit so they "should" be
    # inside the mesh. In reality that doesn't always happen if the mesh is not
    # watertight
    vnormals = mesh.vertex_normals[ix]
    sources = closest_vertex - vnormals

    # Prepare rays to cast
    targets = sources - vnormals * 1e4

    # Cast rays
    hit_ix, loc, is_backface = coll.intersections(sources, targets)

    # Center-point if ray hits, otherwise keep closest_vertex
    final_pos = closest_vertex.copy()

    if len(loc) != 0:
        # Get half-vector
        halfvec = np.zeros(sources.shape)
        halfvec[hit_ix] = (loc - closest_vertex[hit_ix]) / 2

        # Offset vertices
        candidate = closest_vertex + halfvec

        # Keep only those that are properly inside the mesh
        now_inside = coll.contains(candidate)
        final_pos[now_inside] = candidate[now_inside]

    # Try harder to get strictly inside
    still_outside = ~coll.contains(final_pos)
    if still_outside.any():
        push = (
            mesh.edges_unique_length.mean() / 100
            if mesh.edges_unique_length.size
            else 1e-4
        )  # for high-res meshes
        push = max(push, 1e-6)

        candidate_in = closest_vertex - vnormals * push
        candidate_out = closest_vertex + vnormals * push

        in_ok = coll.contains(candidate_in)
        out_ok = coll.contains(candidate_out)

        use_in = still_outside & in_ok
        use_out = still_outside & ~in_ok & out_ok

        final_pos[use_in] = candidate_in[use_in]
        final_pos[use_out] = candidate_out[use_out]

    # Keep only those that are properly inside the mesh and fall back to the
    # closest vertex if that's not the case
    now_inside = coll.contains(final_pos)
    final_pos[~now_inside] = closest_vertex[~now_inside]

    # Replace coordinates
    s.swc.loc[outside, "x"] = final_pos[:, 0]
    s.swc.loc[outside, "y"] = final_pos[:, 1]
    s.swc.loc[outside, "z"] = final_pos[:, 2]

    # At this point we may have nodes that snapped to the same vertex and
    # therefore end up at the same position. We will collapse those nodes
    # - but only if they are actually connected!
    # First find duplicate locations
    u, i, c = np.unique(s.vertices, return_counts=True, return_inverse=True, axis=0)

    # If any coordinates have a count higher than 1
    if c.max() > 1:
        rewire = {}
        # Find out which unique coordinates are duplicated
        dupl = np.where(c > 1)[0]

        # Go over each duplicated coordinate
        for ix in dupl:
            # Find the nodes on this duplicate coordinate
            node_ix = np.where(i == ix)[0]

            # Get their edges
            edges = s.edges[np.all(np.isin(s.edges, node_ix), axis=1)]

            # We will work on the graph to collapse nodes sequentially A->B->C
            G = nx.DiGraph()
            G.add_edges_from(edges)
            for cc in nx.connected_components(G.to_undirected()):
                # Root is the node without any outdegree in this subgraph
                root = [n for n in cc if G.out_degree[n] == 0][0]
                # We don't want to collapse into `root` because it's not actually
                # among the nodes with the same coordinates but rather the "last"
                # node's parent
                clps_into = next(G.predecessors(root))
                # Keep track of how we need to rewire
                rewire.update({c: clps_into for c in cc if c not in {root, clps_into}})

        # Only mess with the skeleton if there were nodes to be merged
        if rewire:
            # Rewire
            s.swc["parent_id"] = s.swc.parent_id.map(lambda x: rewire.get(x, x))

            # Drop nodes that were collapsed
            s.swc = s.swc.loc[~s.swc.node_id.isin(rewire)]

            # Update mesh map
            if s.mesh_map is not None:
                s.mesh_map = [rewire.get(x, x) for x in s.mesh_map]

            # Reindex to make vertex IDs continuous again
            s.reindex(inplace=True)

            # This prevents future SettingWithCopy warnings:
            if not inplace:
                s.swc = s.swc.copy()

    return s

Move nodes that ended up outside the mesh back inside.

Nodes can end up outside the original mesh e.g. if the mesh contraction didn't do a good job (most likely because of internal/degenerate faces that messed up the normals). This function rectifies this by snapping those nodes back to the closest vertex and then tries to move them into the mesh's center. That second step is not guaranteed to work but at least you won't have any more nodes outside the mesh.

Please note that if connected (!) nodes end up on the same position (i.e. because they snapped to the same vertex), we will collapse them.

Parameters
  • s (skeletor.Skeleton):

  • mesh (trimesh.Trimesh): Original mesh.

  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify the s inplace.
Returns
  • s (skeletor.Skeleton): Skeleton with vertices recentered.
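The fallback step, snapping each outside node to its closest mesh vertex, can be sketched with brute-force numpy (the real implementation uses a KD-tree plus ray casting to then push nodes toward the center; `snap_outside` is a hypothetical helper for illustration):

```python
import numpy as np

def snap_outside(nodes, outside, verts):
    """Move nodes flagged as outside onto their closest mesh vertex."""
    out = nodes[outside]
    # Distance from every outside node to every mesh vertex
    d = np.linalg.norm(out[:, None, :] - verts[None, :, :], axis=2)
    snapped = nodes.copy()
    snapped[outside] = verts[d.argmin(axis=1)]
    return snapped

verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
nodes = np.array([[5.0, 0.0, 0.0],   # outside: will snap to vertex [1, 0, 0]
                  [0.1, 0.0, 0.0]])  # inside: left untouched
outside = np.array([True, False])

print(snap_outside(nodes, outside, verts))
```

After this snap, connected nodes may share a position, which is exactly why the function collapses coincident connected nodes afterwards.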
def fix_outside_edges(s, mesh=None, inplace=False, max_iter=8, smooth_iters=1, eps=1e-06):
347def fix_outside_edges(
348    s, mesh=None, inplace=False, max_iter=8, smooth_iters=1, eps=1e-6
349):
350    """Fix edges that cross outside the mesh boundary.
351
352    This function detects skeleton edges that intersect the mesh boundary and
353    fixes them by iteratively splitting crossing edges (inserting new vertices
354    along the edge) and then recentering any vertices that end up outside the
355    mesh using `recenter_vertices()`.
356
357    Notes
358    -----
359    This will also modify original vertices positions (via `skeletor.post.recenter_vertices()`).
360    Splitting edges inserts new skeleton nodes that are not represented in "skel_map". Currently,
361    we invalidate any existing "mesh_map" and "skel_map" (setting it to `None`). In the
362    future, we may add functionality to update the mapping instead.
363
364    Parameters
365    ----------
366    s :         skeletor.Skeleton
367    mesh :      trimesh.Trimesh
368                Original mesh. If mesh is None, will use the mesh associated with input
369                skeleton (`s.mesh`).
370    inplace :   bool
371                If False will make and return a copy of the skeleton. If True,
372                will modify the `s` inplace.
373    max_iter :  int
374                Max split iterations for boundary-crossing edges.
375    smooth_iters : int
376                Number of smoothing iterations for degree-2 chain nodes.
377    eps :       float or {'auto'}
378                Ignore intersections within eps of either endpoint.
379                If "auto", uses mesh mean unique edge length * 1e-4.
380
381    Returns
382    -------
383    s :         skeletor.Skeleton
384
385    """
    if mesh is None:
        mesh = s.mesh

    if mesh is None:
        raise ValueError(
            "Mesh is required for fixing outside edges. Please provide a mesh or ensure `s.mesh` is set."
        )

    if not inplace:
        s = s.copy()

    if s.swc is None or s.swc.empty:
        return s

    # Determine eps (scale-aware)
    if isinstance(eps, str):
        if eps.lower() != "auto":
            raise ValueError("Invalid value for `eps`. Must be a number or 'auto'.")
        try:
            mean_length = float(np.nanmean(mesh.edges_unique_length))
        except Exception:
            mean_length = np.nan
        eps = (
            mean_length * 1e-4
            if (np.isfinite(mean_length) and mean_length > 0)
            else 1e-6
        )
    else:
        eps = float(eps)

    max_iter = int(max_iter)
    if max_iter < 0:
        raise ValueError("`max_iter` must be >= 0")

    smooth_iters = int(smooth_iters)
    if smooth_iters < 0:
        raise ValueError("`smooth_iters` must be >= 0")

    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

    # 1. Recenter any nodes outside the mesh
    if (~coll.contains(s.vertices)).any():
        recenter_vertices(s, mesh=mesh, inplace=True)

    has_radius = "radius" in s.swc.columns

    # 2. Iteratively split crossing edges. Note that positional row indices
    #    are used interchangeably with labels below, which assumes `s.swc`
    #    has a default RangeIndex.
    for _ in range(max_iter):
        swc = s.swc
        edge_rows = np.where(swc.parent_id.values >= 0)[0]
        if edge_rows.size == 0:
            break

        sources = swc.loc[edge_rows, ["x", "y", "z"]].values
        parent_ids = swc.loc[edge_rows, "parent_id"].values
        targets = swc.set_index("node_id").loc[parent_ids, ["x", "y", "z"]].values

        ix, loc, _ = coll.intersections(sources, targets)

        crossing = np.zeros(edge_rows.shape[0], dtype=bool)

        if len(ix):
            d_src = np.linalg.norm(loc - sources[ix], axis=1)
            d_tgt = np.linalg.norm(loc - targets[ix], axis=1)
            real_crossings = (d_src > eps) & (d_tgt > eps)
            if real_crossings.any():
                crossing[np.unique(ix[real_crossings])] = True

        if not crossing.any():
            break

        to_split = edge_rows[crossing]
        next_node_id = int(swc.node_id.max()) + 1

        new_rows = []
        nodes = swc.set_index("node_id")

        # Cache arrays for cheap positional access inside the loop
        parent_id_arr = swc["parent_id"].to_numpy(copy=True)
        parent_col = swc.columns.get_loc("parent_id")

        xyz_arr = swc[["x", "y", "z"]].to_numpy(copy=False)
        if not np.issubdtype(xyz_arr.dtype, np.number):
            xyz_arr = xyz_arr.astype(float)

        for edge_row in to_split:
            # edge_row is a positional row index (from np.where)
            parent_id = int(parent_id_arr[edge_row])
            if parent_id < 0:
                continue

            # Midpoint (no projection; recenter will handle)
            child_co = xyz_arr[edge_row].astype(float, copy=False)
            parent_co = nodes.loc[parent_id, ["x", "y", "z"]].values.astype(
                float, copy=False
            )
            midpoint = (child_co + parent_co) / 2.0

            # Rewire child -> new node (positional write; avoids boolean mask on node_id)
            swc.iat[edge_row, parent_col] = next_node_id
            parent_id_arr[edge_row] = next_node_id  # keep cached view consistent

            # Create new node row
            row = {col: np.nan for col in swc.columns}
            row["node_id"] = next_node_id
            row["parent_id"] = parent_id
            row["x"], row["y"], row["z"] = midpoint

            if has_radius:
                child_r = pd.to_numeric(
                    pd.Series([swc.at[edge_row, "radius"]]), errors="coerce"
                ).iloc[0]
                parent_r = pd.to_numeric(
                    pd.Series([nodes.loc[parent_id, "radius"]]), errors="coerce"
                ).iloc[0]
                row["radius"] = np.nanmean(np.array([child_r, parent_r], dtype=float))

            new_rows.append(row)
            next_node_id += 1

        if not new_rows:
            break

        swc = pd.concat(
            [swc, pd.DataFrame(new_rows, columns=swc.columns)], ignore_index=True
        )
        s.swc = swc

        coords = s.swc[["x", "y", "z"]].values
        if (~coll.contains(coords)).any():
            recenter_vertices(s, mesh=mesh, inplace=True)

    # 3. Smoothing (degree-2 chain nodes), then recenter
    for _ in range(smooth_iters):
        swc = s.swc

        child_counts = swc[swc.parent_id >= 0].groupby("parent_id").size()
        is_chain = (swc.parent_id >= 0) & (
            swc.node_id.map(child_counts).fillna(0).astype(int) == 1
        )
        chain_nodes = swc.loc[is_chain, "node_id"].values.astype(int)

        if chain_nodes.size == 0:
            break

        only_child = (
            swc[swc.parent_id >= 0].groupby("parent_id").node_id.first().to_dict()
        )
        nodes = swc.set_index("node_id")

        parent_ids = nodes.loc[chain_nodes, "parent_id"].values.astype(int)
        child_ids = np.array([only_child[n] for n in chain_nodes], dtype=int)

        parent_co = nodes.loc[parent_ids, ["x", "y", "z"]].values.astype(float)
        child_co = nodes.loc[child_ids, ["x", "y", "z"]].values.astype(float)
        smoothed = (parent_co + child_co) / 2.0

        swc.loc[is_chain, ["x", "y", "z"]] = smoothed
        s.swc = swc

        coords = s.swc[["x", "y", "z"]].values
        if (~coll.contains(coords)).any():
            recenter_vertices(s, mesh=mesh, inplace=True)

    swc = s.swc
    edge_rows = np.where(swc.parent_id.values >= 0)[0]
    remaining_crossings = 0
    # Final check: count any crossing edges that remain
    if edge_rows.size:
        sources = swc.loc[edge_rows, ["x", "y", "z"]].values
        parent_ids = swc.loc[edge_rows, "parent_id"].values
        nodes = swc.set_index("node_id")
        try:
            targets = nodes.loc[parent_ids, ["x", "y", "z"]].values
            ix, loc, _ = coll.intersections(sources, targets)
            if len(ix):
                d_src = np.linalg.norm(loc - sources[ix], axis=1)
                d_tgt = np.linalg.norm(loc - targets[ix], axis=1)
                real = (d_src > eps) & (d_tgt > eps)
                if np.any(real):
                    crossing = np.zeros(edge_rows.shape[0], dtype=bool)
                    crossing[np.unique(ix[real])] = True
                    remaining_crossings = int(crossing.sum())
        except KeyError:
            remaining_crossings = 0

    if remaining_crossings > 0:
        warnings.warn(
            f"{remaining_crossings} crossing edges remain after {max_iter} "
            "fix iteration(s); returning best-effort result. Consider increasing "
            "`max_iter`, adjusting `eps`, and/or running `post.clean_up` / "
            "`post.remove_bristles` first. Also check mesh quality (e.g. non-watertight "
            "or degenerate faces).",
            RuntimeWarning,
        )

    # Invalidate mesh_map (this also invalidates the derived skel_map)
    s.mesh_map = None

    return s
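For illustration, the edge-splitting step above can be reproduced in isolation on a toy SWC table. This is a minimal sketch using a hypothetical two-node skeleton; the mesh intersection test and the subsequent `recenter_vertices()` pass are omitted:

```python
import numpy as np
import pandas as pd

# Hypothetical two-node SWC table: node 0 is the root, node 1 its only child.
swc = pd.DataFrame({
    "node_id": [0, 1],
    "parent_id": [-1, 0],
    "x": [0.0, 2.0], "y": [0.0, 0.0], "z": [0.0, 0.0],
    "radius": [1.0, 3.0],
})

# Split the edge 1 -> 0 at its midpoint, mirroring the loop in the source
child, parent = 1, 0
new_id = int(swc.node_id.max()) + 1  # next free node id: 2
child_co = swc.loc[swc.node_id == child, ["x", "y", "z"]].values[0]
parent_co = swc.loc[swc.node_id == parent, ["x", "y", "z"]].values[0]
midpoint = (child_co + parent_co) / 2.0

# Rewire child -> new node, then append the new node in between
swc.loc[swc.node_id == child, "parent_id"] = new_id
new_row = {
    "node_id": new_id, "parent_id": parent,
    "x": midpoint[0], "y": midpoint[1], "z": midpoint[2],
    # Radius of the new node = mean of child and parent radii
    "radius": np.nanmean([swc.radius.iloc[child], swc.radius.iloc[parent]]),
}
swc = pd.concat([swc, pd.DataFrame([new_row])], ignore_index=True)
print(swc)
```

Note that the new node is placed at the plain midpoint without projecting it back into the mesh; in the full function that correction is deferred to the `recenter_vertices()` pass that follows each split round.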
