skeletor.post
The skeletor.post module contains functions to post-process skeletons after
skeletonization.
Fixing issues with skeletons
Depending on your mesh, pre-processing and the parameters you chose for skeletonization, chances are that your skeleton will not come out perfectly.
skeletor.post.clean_up can help you solve some potential issues:
- skeleton nodes (vertices) that lie outside or right on the surface instead of being centered inside the mesh
- superfluous "hairs" on otherwise straight bits
skeletor.post.smooth will smooth out the skeleton.
skeletor.post.despike can help you remove spikes in the skeleton where
single nodes are out of alignment.
skeletor.post.remove_bristles will remove bristles from the skeleton.
Computing radius information
Only skeletor.skeletonize.by_wavefront() provides radii off the bat. For all
other methods, you can run skeletor.post.radii to
(re-)generate radius information for the skeletons.
# This script is part of skeletor (http://www.github.com/navis-org/skeletor).
# Copyright (C) 2018 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.

r"""
The `skeletor.post` module contains functions to post-process skeletons after
skeletonization.

### Fixing issues with skeletons

Depending on your mesh, pre-processing and the parameters you chose for
skeletonization, chances are that your skeleton will not come out perfectly.

`skeletor.post.clean_up` can help you solve some potential issues:

- skeleton nodes (vertices) that lie outside or right on the surface instead of
  being centered inside the mesh
- superfluous "hairs" on otherwise straight bits

`skeletor.post.smooth` will smooth out the skeleton.

`skeletor.post.despike` can help you remove spikes in the skeleton where
single nodes are out of alignment.

`skeletor.post.remove_bristles` will remove bristles from the skeleton.

### Computing radius information

Only `skeletor.skeletonize.by_wavefront()` provides radii off the bat. For all
other methods, you can run `skeletor.post.radii` to
(re-)generate radius information for the skeletons.

"""

from .radiusextraction import radii
from .postprocessing import clean_up, smooth, despike, remove_bristles, recenter_vertices, fix_outside_edges

__docformat__ = "numpy"
__all__ = ["radii", "clean_up", "smooth", "despike", "remove_bristles", "recenter_vertices", "fix_outside_edges"]
def radii(s, mesh=None, method='knn', aggregate='mean', validate=False, **kwargs):
    """Extract radii for given skeleton table.

    Important
    ---------
    This function really only produces useful radii if the skeleton is centered
    inside the mesh. `by_wavefront` does that by default whereas all other
    skeletonization methods don't. Your best bet to get centered skeletons is
    to contract the mesh first (`sk.pre.contract`).

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to compute radii for.
    mesh : trimesh.Trimesh, optional
        Original mesh (e.g. before contraction). If not provided, will
        use the mesh associated with ``s``.
    method : "knn" | "ray"
        Whether and how to add radius information to each node::

            - "knn" uses k-nearest-neighbors to get radii: fast but
              potentially very wrong
            - "ray" uses ray-casting to get radii: slower but sometimes
              less wrong

    aggregate : "mean" | "median" | "max" | "min" | "percentile75"
        Function used to aggregate radii over the sample (i.e. across
        the k nearest neighbors or the ray intersections)
    validate : bool
        If True, will try to fix potential issues with the mesh
        (e.g. infinite values, duplicate vertices, degenerate faces)
        before computing radii. Note that this might change your
        mesh in place!
    **kwargs
        Keyword arguments are passed to the respective method:

        For method "knn"::

            n : int (default 5)
                Number of nearest neighbors over which the radius is
                aggregated (see ``aggregate``).

        For method "ray"::

            n_rays : int (default 20)
                Number of rays to cast for each node.
            projection : "sphere" (default) | "tangents"
                Whether to cast rays in a sphere around each node or in a
                circle orthogonal to the node's tangent vector.
            fallback : "knn" (default) | None | number
                If a point is outside or right on the surface of the mesh,
                ray casting will return nonsense results. We can either
                ignore those cases (``None``), assign an arbitrary number, or
                fall back to radii from k-nearest-neighbors (``knn``).

    Returns
    -------
    None
        But attaches `radius` to the skeleton's SWC table. Existing
        values are replaced!

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    # Only try to repair the mesh if requested via `validate`
    mesh = make_trimesh(mesh, validate=validate)

    if method == 'knn':
        radius = get_radius_knn(s.swc[['x', 'y', 'z']].values,
                                aggregate=aggregate,
                                mesh=mesh, **kwargs)
    elif method == 'ray':
        radius = get_radius_ray(s.swc,
                                mesh=mesh,
                                aggregate=aggregate,
                                **kwargs)
    else:
        raise ValueError(f'Unknown method "{method}"')

    s.swc['radius'] = radius

    return
Extract radii for given skeleton table.
Important
This function really only produces useful radii if the skeleton is centered
inside the mesh. by_wavefront does that by default whereas all other
skeletonization methods don't. Your best bet to get centered skeletons is
to contract the mesh first (sk.pre.contract).
Parameters
- s (skeletor.Skeleton): Skeleton to compute radii for.
- mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided, will use the mesh associated with s.
- method ("knn" | "ray"): Whether and how to add radius information to each node:
  - "knn" uses k-nearest-neighbors to get radii: fast but potentially very wrong
  - "ray" uses ray-casting to get radii: slower but sometimes less wrong
- aggregate ("mean" | "median" | "max" | "min" | "percentile75"): Function used to aggregate radii over the sample (i.e. across the k nearest neighbors or the ray intersections)
- validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before computing radii. Note that this might change your mesh in place!
- **kwargs: Keyword arguments are passed to the respective method:
  For method "knn":
    n : int (default 5)
        Number of nearest neighbors over which the radius is aggregated (see aggregate).
  For method "ray":
    n_rays : int (default 20)
        Number of rays to cast for each node.
    projection : "sphere" (default) | "tangents"
        Whether to cast rays in a sphere around each node or in a
        circle orthogonal to the node's tangent vector.
    fallback : "knn" (default) | None | number
        If a point is outside or right on the surface of the mesh,
        ray casting will return nonsense results. We can either
        ignore those cases (None), assign an arbitrary number, or
        fall back to radii from k-nearest-neighbors ("knn").
Returns
- None: But attaches radius to the skeleton's SWC table. Existing values are replaced!
def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
    """Clean up the skeleton.

    This function bundles a bunch of procedures to clean up the skeleton:

      1. Remove twigs that are running parallel to their parent branch
      2. Move nodes outside the mesh back inside (or at least snap to surface)

    Note that this is not a magic bullet and some of this will not work (well)
    if the original mesh was degenerate (e.g. internal faces or not watertight)
    to begin with.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to clean up.
    mesh : trimesh.Trimesh, optional
        Original mesh (e.g. before contraction). If not provided, will
        use the mesh associated with ``s``.
    validate : bool
        If True, will try to fix potential issues with the mesh
        (e.g. infinite values, duplicate vertices, degenerate faces)
        before cleaning up. Note that this might change your mesh
        in place!
    inplace : bool
        If False, will make and return a copy of the skeleton. If True,
        will modify ``s`` in place.

    **kwargs
        Keyword arguments are passed to the bundled function.

        For `skeletor.postprocessing.drop_parallel_twigs`::

            theta : float (default 0.01)
                For each twig we compute the dot product between its tangent
                vector and that of its parent. If these line up perfectly, the
                dot product will equal 1. ``theta`` determines how much that
                value can differ from 1 for us to still prune the twig: higher
                theta = more pruning.

    Returns
    -------
    s_clean : skeletor.Skeleton
        Hopefully improved skeleton.

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    mesh = make_trimesh(mesh, validate=validate)

    if not inplace:
        s = s.copy()

    # Drop parallel twigs
    _ = drop_parallel_twigs(s, theta=kwargs.get("theta", 0.01), inplace=True)

    # Recenter vertices
    _ = recenter_vertices(s, mesh, inplace=True)

    return s
Clean up the skeleton.
This function bundles a bunch of procedures to clean up the skeleton:
- Remove twigs that are running parallel to their parent branch
- Move nodes outside the mesh back inside (or at least snap to surface)
Note that this is not a magic bullet and some of this will not work (well) if the original mesh was degenerate (e.g. internal faces or not watertight) to begin with.
Parameters
- s (skeletor.Skeleton): Skeleton to clean up.
- mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided, will use the mesh associated with s.
- validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before cleaning up. Note that this might change your mesh in place!
- inplace (bool): If False, will make and return a copy of the skeleton. If True, will modify s in place.
- **kwargs: Keyword arguments are passed to the bundled function.
  For skeletor.postprocessing.drop_parallel_twigs:
    theta : float (default 0.01)
        For each twig we compute the dot product between its tangent
        vector and that of its parent. If these line up perfectly, the
        dot product will equal 1. theta determines how much that
        value can differ from 1 for us to still prune the twig: higher
        theta = more pruning.
Returns
- s_clean (skeletor.Skeleton): Hopefully improved skeleton.
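The theta criterion described for drop_parallel_twigs can be illustrated with a small sketch: normalize the twig's and the parent's tangent vectors, take their dot product, and treat the twig as parallel when the result lies within theta of 1. The helper is_parallel_twig is hypothetical, and how skeletor actually derives the tangent vectors is not shown in this page, so the sketch simply takes them as (non-zero) inputs.

```python
import math


def _unit(v):
    # Normalize a non-zero 3D vector to unit length
    length = math.sqrt(sum(c * c for c in v))
    return tuple(c / length for c in v)


def is_parallel_twig(twig_tangent, parent_tangent, theta=0.01):
    """Return True if the twig runs (almost) parallel to its parent branch."""
    dot = sum(a * b for a, b in zip(_unit(twig_tangent), _unit(parent_tangent)))
    # Perfect alignment gives dot == 1; a higher theta tolerates more deviation
    return 1 - dot <= theta


print(is_parallel_twig((1, 0, 0), (2, 0.05, 0)))          # nearly aligned: True
print(is_parallel_twig((1, 0, 0), (1, 1, 0)))             # 45 degrees apart: False
print(is_parallel_twig((1, 0, 0), (1, 1, 0), theta=0.5))  # lenient theta: True
```

The last call shows the documented trade-off: raising theta turns more twigs into pruning candidates.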
def smooth(
    s, window: int = 3, to_smooth: list = ["x", "y", "z"], inplace: bool = False
):
    """Smooth skeleton using rolling windows.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to be processed.
    window : int, optional
        Size (N observations) of the rolling window in number of
        nodes.
    to_smooth : list
        Columns of the node table to smooth. Should work with any
        numeric column (e.g. 'radius').
    inplace : bool
        If False, will make and return a copy of the skeleton. If
        True, will modify ``s`` in place.

    Returns
    -------
    s : skeletor.Skeleton
        Skeleton with smoothed node table.

    """
    if not inplace:
        s = s.copy()

    # Index the node table by node ID (copy so the original stays untouched)
    nodes = s.swc.set_index("node_id", inplace=False).copy()

    to_smooth = np.array(to_smooth)
    miss = to_smooth[~np.isin(to_smooth, nodes.columns)]
    if len(miss):
        raise ValueError(f"Column(s) not found in node table: {miss}")

    # Go over each segment and smooth
    for seg in s.get_segments():
        # Get this segment's values for the columns to smooth
        this_co = nodes.loc[seg, to_smooth]

        # Trailing rolling mean: each node is averaged with up to
        # `window - 1` preceding nodes of the same segment
        interp = this_co.rolling(window, min_periods=1).mean()

        nodes.loc[seg, to_smooth] = interp.values

    # Reassign nodes
    s.swc = nodes.reset_index(drop=False, inplace=False)

    return s
Smooth skeleton using rolling windows.
Parameters
- s (skeletor.Skeleton): Skeleton to be processed.
- window (int, optional): Size (N observations) of the rolling window in number of nodes.
- to_smooth (list): Columns of the node table to smooth. Should work with any numeric column (e.g. 'radius').
- inplace (bool): If False, will make and return a copy of the skeleton. If True, will modify s in place.
Returns
- s (skeletor.Skeleton): Skeleton with smoothed node table.
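In the source above, smooth applies pandas' rolling(window, min_periods=1).mean() to each segment. That is a trailing window: each node is averaged with up to window - 1 nodes that precede it in the segment, and the first nodes use whatever is available. The stdlib-only sketch below reproduces that behavior for one column of one segment; rolling_mean is a hypothetical helper, not part of skeletor.

```python
def rolling_mean(values, window=3):
    """Trailing rolling mean with min_periods=1, like pandas' rolling(window, min_periods=1).mean()."""
    out = []
    for i in range(len(values)):
        # Current node plus up to window - 1 predecessors in the segment
        chunk = values[max(0, i - window + 1) : i + 1]
        out.append(sum(chunk) / len(chunk))
    return out


# x-coordinates along one segment with a single out-of-line node
print(rolling_mean([0.0, 0.0, 3.0, 0.0, 0.0], window=3))  # [0.0, 0.0, 1.0, 1.0, 1.0]
```

Note how the bump at index 2 is spread over the following nodes: a trailing window shifts features toward the end of each segment, which is a property of the rolling call itself.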
def despike(s, sigma=5, max_spike_length=1, inplace=False, reverse=False):
    r"""Remove spikes in skeleton.

    For each node A, the Euclidean distances to its successor (parent)
    B and to B's successor C (i.e. A->B->C) are computed. If
    :math:`\frac{dist(A,B)}{dist(A,C)}>sigma`, node B is considered a spike
    and moved to the midpoint between A and C.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to be processed.
    sigma : float | int, optional
        Threshold for spike detection. Smaller sigma = more
        aggressive spike detection.
    max_spike_length : int, optional
        Determines how long (# of nodes) a spike can be.
    inplace : bool, optional
        If False, a copy of the skeleton is returned.
    reverse : bool, optional
        If True, will **also** walk the segments from proximal
        to distal. Use this to catch spikes on e.g. terminal
        nodes.

    Returns
    -------
    s : skeletor.Skeleton
        Despiked skeleton.

    """
    if not inplace:
        s = s.copy()

    # Index nodes table by node ID
    this_nodes = s.swc.set_index("node_id", inplace=False)

    segments = s.get_segments()
    segs_to_walk = list(segments)

    if reverse:
        # Also walk each segment in the opposite direction
        segs_to_walk += [seg[::-1] for seg in segments]

    # For each spike length do -> do this in reverse to correct the long
    # spikes first
    for l in list(range(1, max_spike_length + 1))[::-1]:
        # Go over all segments
        for seg in segs_to_walk:
            # Get nodes A, B and C of this segment
            this_A = this_nodes.loc[seg[: -l - 1]]
            this_B = this_nodes.loc[seg[l:-1]]
            this_C = this_nodes.loc[seg[l + 1 :]]

            # Get coordinates
            A = this_A[["x", "y", "z"]].values
            B = this_B[["x", "y", "z"]].values
            C = this_C[["x", "y", "z"]].values

            # Calculate Euclidean distances A->B and A->C
            dist_AB = np.linalg.norm(A - B, axis=1)
            dist_AC = np.linalg.norm(A - C, axis=1)

            # Get the spikes; where dist(A,C) == 0 the ratio is set to 0
            # instead of being left uninitialized
            ratio = np.divide(
                dist_AB, dist_AC, out=np.zeros_like(dist_AB), where=dist_AC != 0
            )
            spikes_ix = np.where(ratio > sigma)[0]
            spikes = this_B.iloc[spikes_ix]

            if not spikes.empty:
                # Interpolate new position(s) between A and C
                new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2

                this_nodes.loc[spikes.index, ["x", "y", "z"]] = new_positions

    # Reassign node table
    s.swc = this_nodes.reset_index(drop=False, inplace=False)

    return s
Remove spikes in skeleton.
For each node A, the Euclidean distances to its successor (parent) B and to B's successor C (i.e. A->B->C) are computed. If \( \frac{dist(A,B)}{dist(A,C)} > sigma \), node B is considered a spike and moved to the midpoint between A and C.
Parameters
- s (skeletor.Skeleton): Skeleton to be processed.
- sigma (float | int, optional): Threshold for spike detection. Smaller sigma = more aggressive spike detection.
- max_spike_length (int, optional): Determines how long (# of nodes) a spike can be.
- inplace (bool, optional): If False, a copy of the skeleton is returned.
- reverse (bool, optional): If True, will also walk the segments from proximal to distal. Use this to catch spikes on e.g. terminal nodes.
Returns
- s (skeletor.Skeleton): Despiked skeleton.
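The spike test for max_spike_length=1 can be sketched on a single segment of coordinates: for each consecutive triple A->B->C, B is moved to the midpoint of A and C when dist(A,B)/dist(A,C) exceeds sigma. This is a simplified, stdlib-only illustration (one segment, one pass, a hypothetical despike_segment helper). Like the vectorized source above, it reads A, B and C from the original coordinates, and it skips triples with dist(A,C) == 0 instead of dividing by zero.

```python
import math


def despike_segment(coords, sigma=5.0):
    """Move single-node spikes in one segment (list of (x, y, z) tuples) onto the A-C midpoint."""
    out = list(coords)
    for i in range(len(coords) - 2):
        # Always read from the original coordinates, not from `out`
        a, b, c = coords[i], coords[i + 1], coords[i + 2]
        d_ac = math.dist(a, c)
        if d_ac == 0:
            continue  # avoid division by zero; such triples are left alone
        if math.dist(a, b) / d_ac > sigma:
            out[i + 1] = tuple((p + q) / 2 for p, q in zip(a, c))
    return out


segment = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 20.0, 0.0), (3.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
print(despike_segment(segment, sigma=5.0))  # the (2, 20, 0) node moves to (2.0, 0.0, 0.0)
```

On an evenly spaced straight segment the ratio is 0.5, so the default sigma=5 only flags nodes that stick out roughly ten times further than the local node spacing.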
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
    """Remove "bristles" that sometimes occur along the backbone.

    Works by finding terminal twigs that consist of only a single node.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to clean up.
    mesh : trimesh.Trimesh, optional
        Original mesh (e.g. before contraction). If not provided, will
        use the mesh associated with ``s``.
    los_only : bool
        If True, will only remove bristles that are in line of sight of
        their parent. If False, will remove all single-node bristles.
    inplace : bool
        If False, will make and return a copy of the skeleton. If True,
        will modify ``s`` in place.

    Returns
    -------
    s : skeletor.Skeleton
        Skeleton with single-node twigs removed.

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    # Make a copy of the skeleton
    if not inplace:
        s = s.copy()

    # Find branch points
    pcount = s.swc[s.swc.parent_id >= 0].groupby("parent_id").size()
    bp = pcount[pcount > 1].index

    # Find terminal twigs
    twigs = s.swc[~s.swc.node_id.isin(s.swc.parent_id)]
    twigs = twigs[twigs.parent_id.isin(bp)]

    if twigs.empty:
        return s

    if los_only:
        # Initialize ncollpyde Volume
        coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

        # Remove twigs that aren't inside the volume
        twigs = twigs[coll.contains(twigs[["x", "y", "z"]].values)]

        # Generate rays between all pairs and their parents
        sources = twigs[["x", "y", "z"]].values
        targets = (
            s.swc.set_index("node_id").loc[twigs.parent_id, ["x", "y", "z"]].values
        )

        # Get intersections: `ix` points to index of line segment; `loc` is the
        # x/y/z coordinate of the intersection and `is_backface` is True if
        # intersection happened at the inside of a mesh
        ix, loc, is_backface = coll.intersections(sources, targets)

        # Find pairs of twigs with no intersection - i.e. with line of sight
        los = ~np.isin(np.arange(sources.shape[0]), ix)

        # To remove: have line of sight
        to_remove = twigs[los]
    else:
        to_remove = twigs

    s.swc = s.swc[~s.swc.node_id.isin(to_remove.node_id)].copy()

    # Update the mesh map
    mesh_map = getattr(s, "mesh_map", None)
    if not isinstance(mesh_map, type(None)):
        for t in to_remove.itertuples():
            mesh_map[mesh_map == t.node_id] = t.parent_id

    # Reindex nodes
    s.reindex(inplace=True)

    return s
Remove "bristles" that sometimes occur along the backbone.
Works by finding terminal twigs that consist of only a single node.
Parameters
- s (skeletor.Skeleton): Skeleton to clean up.
- mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided, will use the mesh associated with s.
- los_only (bool): If True, will only remove bristles that are in line of sight of their parent. If False, will remove all single-node bristles.
- inplace (bool): If False, will make and return a copy of the skeleton. If True, will modify s in place.
Returns
- s (skeletor.Skeleton): Skeleton with single-node twigs removed.
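With los_only=False, the bristle search in the source above boils down to two table lookups: branch points are nodes with more than one child, and bristles are leaf nodes whose parent is a branch point. The sketch below mirrors that logic on a plain node_id -> parent_id dictionary in which the root has a negative parent ID, as in the SWC tables used here; find_bristles is a hypothetical helper, and the line-of-sight check is left out.

```python
from collections import Counter


def find_bristles(parents):
    """Return leaf nodes whose parent is a branch point (i.e. single-node twigs)."""
    # Count children per node, ignoring the root's negative parent ID
    child_counts = Counter(p for p in parents.values() if p >= 0)
    branch_points = {n for n, k in child_counts.items() if k > 1}
    leaves = [n for n in parents if n not in child_counts]
    return sorted(n for n in leaves if parents[n] in branch_points)


# Backbone 0-1-2-3-4; node 5 is a one-node bristle on node 1,
# while 6-7 is a two-node twig on node 2 and is therefore kept
parents = {0: -1, 1: 0, 2: 1, 3: 2, 4: 3, 5: 1, 6: 2, 7: 6}
print(find_bristles(parents))  # [5]
```

The backbone tip (node 4) is kept only because its parent has a single child; a one-node tip attached directly to a branch point would be flagged as well, which is exactly when los_only=True becomes useful.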
def recenter_vertices(s, mesh=None, inplace=False):
    """Move nodes that ended up outside the mesh back inside.

    Nodes can end up outside the original mesh, e.g. if the mesh contraction
    didn't do a good job (most likely because of internal/degenerate faces that
    messed up the normals). This function rectifies this by snapping those
    nodes back to the closest vertex and then trying to move them into the
    mesh's center. That second step is not guaranteed to work, but at least you
    won't have any more nodes outside the mesh.

    Please note that if connected (!) nodes end up at the same position (i.e.
    because they snapped to the same vertex), we will collapse them.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to process.
    mesh : trimesh.Trimesh, optional
        Original mesh. If not provided, will use the mesh associated
        with ``s``.
    inplace : bool
        If False, will make and return a copy of the skeleton. If True,
        will modify ``s`` in place.

    Returns
    -------
    s : skeletor.Skeleton
        Skeleton with vertices recentered.

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    # Copy skeleton
    if not inplace:
        s = s.copy()

    # Find nodes that are outside the mesh
    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)
    outside = ~coll.contains(s.vertices)

    # Skip if all inside
    if not any(outside):
        return s

    # For each outside find the closest vertex
    tree = scipy.spatial.cKDTree(mesh.vertices)

    # Find nodes that are right on top of original vertices
    dist, ix = tree.query(s.vertices[outside])

    # We don't want to just snap them back to the closest vertex but try to find
    # the center. For this we will:
    # 1. Move each vertex inside the mesh by just a bit
    # 2. Cast a ray along the vertices' normals and find the opposite sides of the mesh
    # 3. Calculate the distance

    # Get the closest vertex...
    closest_vertex = mesh.vertices[ix]
    # .. and offset the vertex positions by just a bit so they "should" be
    # inside the mesh. In reality that doesn't always happen if the mesh is not
    # watertight
    vnormals = mesh.vertex_normals[ix]
    sources = closest_vertex - vnormals

    # Prepare rays to cast
    targets = sources - vnormals * 1e4

    # Cast rays
    hit_ix, loc, is_backface = coll.intersections(sources, targets)

    # center-point if ray hits, otherwise keep closest_vertex
    final_pos = closest_vertex.copy()

    if len(loc) != 0:
        # Get half-vector
        halfvec = np.zeros(sources.shape)
        halfvec[hit_ix] = (loc - closest_vertex[hit_ix]) / 2

        # Offset vertices
        candidate = closest_vertex + halfvec

        # Keep only those that are properly inside the mesh
        now_inside = coll.contains(candidate)
        final_pos[now_inside] = candidate[now_inside]

    # try harder to get strictly inside
    still_outside = ~coll.contains(final_pos)
    if still_outside.any():
        push = (
            mesh.edges_unique_length.mean() / 100
            if mesh.edges_unique_length.size
            else 1e-4
        )  # for high-res meshes
        push = max(push, 1e-6)

        candidate_in = closest_vertex - vnormals * push
        candidate_out = closest_vertex + vnormals * push

        in_ok = coll.contains(candidate_in)
        out_ok = coll.contains(candidate_out)

        use_in = still_outside & in_ok
        use_out = still_outside & ~in_ok & out_ok

        final_pos[use_in] = candidate_in[use_in]
        final_pos[use_out] = candidate_out[use_out]

    # Keep only those that are properly inside the mesh and fall back to the
    # closest vertex if that's not the case
    now_inside = coll.contains(final_pos)
    final_pos[~now_inside] = closest_vertex[~now_inside]

    # Replace coordinates
    s.swc.loc[outside, "x"] = final_pos[:, 0]
    s.swc.loc[outside, "y"] = final_pos[:, 1]
    s.swc.loc[outside, "z"] = final_pos[:, 2]

    # At this point we may have nodes that snapped to the same vertex and
    # therefore end up at the same position. We will collapse those nodes
    # - but only if they are actually connected!
    # First find duplicate locations
    u, i, c = np.unique(s.vertices, return_counts=True, return_inverse=True, axis=0)

    # If any coordinates have a count higher than 1
    if c.max() > 1:
        rewire = {}
        # Find out which unique coordinates are duplicated
        dupl = np.where(c > 1)[0]

        # Go over each duplicated coordinate
        for ix in dupl:
            # Find the nodes on this duplicate coordinate
            node_ix = np.where(i == ix)[0]

            # Get their edges
            edges = s.edges[np.all(np.isin(s.edges, node_ix), axis=1)]

            # We will work on the graph to collapse nodes sequentially A->B->C
            G = nx.DiGraph()
            G.add_edges_from(edges)
            for cc in nx.connected_components(G.to_undirected()):
                # Root is the node without any outdegree in this subgraph
                root = [n for n in cc if G.out_degree[n] == 0][0]
                # We don't want to collapse into `root` because it's not actually
                # among the nodes with the same coordinates but rather the "last"
                # node's parent
                clps_into = next(G.predecessors(root))
                # Keep track of how we need to rewire
                rewire.update({n: clps_into for n in cc if n not in {root, clps_into}})

        # Only mess with the skeleton if there were nodes to be merged
        if rewire:
            # Rewire
            s.swc["parent_id"] = s.swc.parent_id.map(lambda x: rewire.get(x, x))

            # Drop nodes that were collapsed
            s.swc = s.swc.loc[~s.swc.node_id.isin(rewire)]

            # Update mesh map
            if not isinstance(s.mesh_map, type(None)):
                s.mesh_map = [rewire.get(x, x) for x in s.mesh_map]

            # Reindex to make vertex IDs continuous again
            s.reindex(inplace=True)

            # This prevents future SettingWithCopy warnings:
            if not inplace:
                s.swc = s.swc.copy()

    return s
Move nodes that ended up outside the mesh back inside.
Nodes can end up outside the original mesh, e.g. if the mesh contraction didn't do a good job (most likely because of internal/degenerate faces that messed up the normals). This function rectifies this by snapping those nodes back to the closest vertex and then trying to move them into the mesh's center. That second step is not guaranteed to work, but at least you won't have any more nodes outside the mesh.
Please note that if connected (!) nodes end up at the same position (i.e. because they snapped to the same vertex), we will collapse them.
Parameters
- s (skeletor.Skeleton): Skeleton to process.
- mesh (trimesh.Trimesh, optional): Original mesh. If not provided, will use the mesh associated with s.
- inplace (bool): If False, will make and return a copy of the skeleton. If True, will modify s in place.
Returns
- s (skeletor.Skeleton): Skeleton with vertices recentered.
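The recentering idea (snap to the closest surface point, cast a ray inward along the normal, place the node halfway to where the ray exits on the opposite side) can be checked on an analytic sphere, where the ray intersection has a closed form and the halfway point must be the center. This is a simplified sketch under stated assumptions, not skeletor's implementation: an exact sphere replaces the mesh, a hypothetical ray_sphere_exit helper replaces ncollpyde's ray casting, and the exact closest surface point replaces the nearest mesh vertex.

```python
import math


def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def ray_sphere_exit(origin, direction, center, radius):
    """Distance along a unit `direction` to where a ray from `origin` (inside or on the sphere) exits it."""
    oc = _sub(origin, center)
    b = sum(d * o for d, o in zip(direction, oc))
    c = sum(o * o for o in oc) - radius * radius
    return -b + math.sqrt(b * b - c)  # far root of t^2 + 2bt + c = 0


def recenter_on_sphere(node, center, radius):
    """Snap an outside node to the sphere surface, then move it halfway to the opposite side."""
    offset = _sub(node, center)
    norm = math.sqrt(sum(x * x for x in offset))
    normal = tuple(x / norm for x in offset)  # outward unit normal at the snapped point
    surface = tuple(ci + radius * n for ci, n in zip(center, normal))
    inward = tuple(-n for n in normal)
    t = ray_sphere_exit(surface, inward, center, radius)
    # Half-vector from the surface point to the opposite hit
    return tuple(p + d * t / 2 for p, d in zip(surface, inward))


print(recenter_on_sphere((0.0, 0.0, 5.0), center=(1.0, 1.0, 1.0), radius=2.0))  # close to (1, 1, 1)
```

On a sphere every inward normal passes through the center, so the halfway point lands exactly there; on tube-like meshes the same construction only approximates the local medial axis, which is why the source keeps the inside checks and the closest-vertex fallback.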
def fix_outside_edges(
    s, mesh=None, inplace=False, max_iter=8, smooth_iters=1, eps=1e-6
):
    """Fix edges that cross outside the mesh boundary.

    This function detects skeleton edges that intersect the mesh boundary and
    fixes them by iteratively splitting crossing edges (inserting new vertices
    along the edge) and then recentering any vertices that end up outside the
    mesh using `recenter_vertices()`.

    Notes
    -----
    This will also modify original vertex positions (via `skeletor.post.recenter_vertices()`).
    Splitting edges inserts new skeleton nodes that are not represented in "skel_map". Currently,
    we invalidate any existing "mesh_map" and "skel_map" (setting them to `None`). In the
    future, we may add functionality to update the mapping instead.

    Parameters
    ----------
    s : skeletor.Skeleton
        Skeleton to fix.
    mesh : trimesh.Trimesh, optional
        Original mesh. If mesh is None, will use the mesh associated with input
        skeleton (`s.mesh`).
    inplace : bool
        If False, will make and return a copy of the skeleton. If True,
        will modify ``s`` in place.
    max_iter : int
        Max split iterations for boundary-crossing edges.
    smooth_iters : int
        Number of smoothing iterations for degree-2 chain nodes.
    eps : float or {'auto'}
        Ignore intersections within eps of either endpoint.
        If "auto", uses mesh mean unique edge length * 1e-4.

    Returns
    -------
    s : skeletor.Skeleton
        Skeleton with boundary-crossing edges fixed (best effort).

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    if mesh is None:
        raise ValueError(
            "Mesh is required for fixing outside edges. Please provide a mesh or ensure `s.mesh` is set."
        )

    if not inplace:
        s = s.copy()

    if s.swc is None or s.swc.empty:
        return s

    # Determine eps (scale-aware)
    if isinstance(eps, str):
        if eps.lower() != "auto":
            raise ValueError("Invalid value for `eps`. Must be a number or 'auto'.")
        try:
            mean_length = float(np.nanmean(mesh.edges_unique_length))
        except Exception:
            mean_length = np.nan
        eps = (
            mean_length * 1e-4
            if (np.isfinite(mean_length) and mean_length > 0)
            else 1e-6
        )
    else:
        eps = float(eps)

    max_iter = int(max_iter)
    if max_iter < 0:
        raise ValueError("`max_iter` must be >= 0")

    smooth_iters = int(smooth_iters)
    if smooth_iters < 0:
        raise ValueError("`smooth_iters` must be >= 0")

    coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

    # 1. Recenter any nodes outside the mesh
    if (~coll.contains(s.vertices)).any():
        recenter_vertices(s, mesh=mesh, inplace=True)

    has_radius = "radius" in s.swc.columns

    # 2. Iteratively split crossing edges
    for _ in range(max_iter):
        swc = s.swc
        edge_rows = np.where(swc.parent_id.values >= 0)[0]
        if edge_rows.size == 0:
            break

        sources = swc.loc[edge_rows, ["x", "y", "z"]].values
        parent_ids = swc.loc[edge_rows, "parent_id"].values
        targets = swc.set_index("node_id").loc[parent_ids, ["x", "y", "z"]].values

        ix, loc, _ = coll.intersections(sources, targets)

        crossing = np.zeros(edge_rows.shape[0], dtype=bool)

        if len(ix):
            d_src = np.linalg.norm(loc - sources[ix], axis=1)
            d_tgt = np.linalg.norm(loc - targets[ix], axis=1)
            real_crossings = (d_src > eps) & (d_tgt > eps)
            if real_crossings.any():
                crossing[np.unique(ix[real_crossings])] = True

        if not crossing.any():
            break

        to_split = edge_rows[crossing]
        next_node_id = int(swc.node_id.max()) + 1

        new_rows = []
        nodes = swc.set_index("node_id")

        # Cache arrays for cheap positional access inside the loop
        parent_id_arr = swc["parent_id"].to_numpy(copy=True)
        parent_col = swc.columns.get_loc("parent_id")

        xyz_arr = swc[["x", "y", "z"]].to_numpy(copy=False)
        if not np.issubdtype(xyz_arr.dtype, np.number):
            xyz_arr = xyz_arr.astype(float)

        for edge_row in to_split:
            # edge_row is a positional row index (from np.where)
            parent_id = int(parent_id_arr[edge_row])
            if parent_id < 0:
                continue

            # Midpoint (no projection; recenter will handle)
            child_co = xyz_arr[edge_row].astype(float, copy=False)
            parent_co = nodes.loc[parent_id, ["x", "y", "z"]].values.astype(
                float, copy=False
            )
            midpoint = (child_co + parent_co) / 2.0

            # Rewire child -> new node (positional write; avoids boolean mask on node_id)
            swc.iat[edge_row, parent_col] = next_node_id
            parent_id_arr[edge_row] = next_node_id  # keep cached view consistent

            # Create new node row
            row = {col: np.nan for col in swc.columns}
            row["node_id"] = next_node_id
            row["parent_id"] = parent_id
            row["x"], row["y"], row["z"] = midpoint

            if has_radius:
                child_r = pd.to_numeric(
                    pd.Series([swc.at[edge_row, "radius"]]), errors="coerce"
                ).iloc[0]
                parent_r = pd.to_numeric(
                    pd.Series([nodes.loc[parent_id, "radius"]]), errors="coerce"
                ).iloc[0]
                row["radius"] = np.nanmean(np.array([child_r, parent_r], dtype=float))

            new_rows.append(row)
            next_node_id += 1

        if not new_rows:
            break

        swc = pd.concat(
            [swc, pd.DataFrame(new_rows, columns=swc.columns)], ignore_index=True
        )
        s.swc = swc

        coords = s.swc[["x", "y", "z"]].values
        if (~coll.contains(coords)).any():
            recenter_vertices(s, mesh=mesh, inplace=True)

    # 3. Smoothing (degree-2 chain nodes), then recenter
    for _ in range(smooth_iters):
        swc = s.swc

        child_counts = swc[swc.parent_id >= 0].groupby("parent_id").size()
        is_chain = (swc.parent_id >= 0) & (
            swc.node_id.map(child_counts).fillna(0).astype(int) == 1
        )
        chain_nodes = swc.loc[is_chain, "node_id"].values.astype(int)

        if chain_nodes.size == 0:
            break

        only_child = (
            swc[swc.parent_id >= 0].groupby("parent_id").node_id.first().to_dict()
        )
        nodes = swc.set_index("node_id")

        parent_ids = nodes.loc[chain_nodes, "parent_id"].values.astype(int)
        child_ids = np.array([only_child[n] for n in chain_nodes], dtype=int)

        parent_co = nodes.loc[parent_ids, ["x", "y", "z"]].values.astype(float)
        child_co = nodes.loc[child_ids, ["x", "y", "z"]].values.astype(float)
        smoothed = (parent_co + child_co) / 2.0

        swc.loc[is_chain, ["x", "y", "z"]] = smoothed
        s.swc = swc

        coords = s.swc[["x", "y", "z"]].values
        if (~coll.contains(coords)).any():
            recenter_vertices(s, mesh=mesh, inplace=True)

    swc = s.swc
    edge_rows = np.where(swc.parent_id.values >= 0)[0]
    remaining_crossings = 0
    # Detect crossing edges again for double-checking
    if edge_rows.size:
        sources = swc.loc[edge_rows, ["x", "y", "z"]].values
        parent_ids = swc.loc[edge_rows, "parent_id"].values
        nodes = swc.set_index("node_id")
        try:
            targets = nodes.loc[parent_ids, ["x", "y", "z"]].values
            ix, loc, _ = coll.intersections(sources, targets)
            if len(ix):
                d_src = np.linalg.norm(loc - sources[ix], axis=1)
                d_tgt = np.linalg.norm(loc - targets[ix], axis=1)
                real = (d_src > eps) & (d_tgt > eps)
                if np.any(real):
                    crossing = np.zeros(edge_rows.shape[0], dtype=bool)
                    crossing[np.unique(ix[real])] = True
                    remaining_crossings = int(crossing.sum())
        except KeyError:
            remaining_crossings = 0

    if remaining_crossings > 0:
        warnings.warn(
            f"{remaining_crossings} crossing edges remain after {max_iter} "
            "fix iteration(s); returning best-effort result. Consider increasing "
            "`max_iter`, adjusting `eps`, and/or running `post.clean_up` / "
            "`post.remove_bristles` first. Also check mesh quality (e.g. non-watertight "
            "or degenerate faces).",
            RuntimeWarning,
        )

    # Invalidate mesh_map
    s.mesh_map = None

    return s
Fix edges that cross outside the mesh boundary.
This function detects skeleton edges that intersect the mesh boundary and
fixes them by iteratively splitting crossing edges (inserting new vertices
along the edge) and then recentering any vertices that end up outside the
mesh using recenter_vertices().
Notes
This will also modify original vertex positions (via skeletor.post.recenter_vertices()).
Splitting edges inserts new skeleton nodes that are not represented in "skel_map". Currently,
we invalidate any existing "mesh_map" and "skel_map" (setting them to None). In the
future, we may add functionality to update the mapping instead.
Parameters
- s (skeletor.Skeleton): Skeleton to fix.
- mesh (trimesh.Trimesh, optional): Original mesh. If mesh is None, will use the mesh associated with the input skeleton (s.mesh).
- inplace (bool): If False, will make and return a copy of the skeleton. If True, will modify s in place.
- max_iter (int): Max split iterations for boundary-crossing edges.
- smooth_iters (int): Number of smoothing iterations for degree-2 chain nodes.
- eps (float or {'auto'}): Ignore intersections within eps of either endpoint. If "auto", uses mesh mean unique edge length * 1e-4.
Returns
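- s (skeletor.Skeleton): Skeleton with boundary-crossing edges fixed (best effort; a RuntimeWarning is issued if crossings remain).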
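The bookkeeping behind the edge-splitting step can be sketched on a plain list-of-dicts node table: a midpoint node is inserted between a child and its parent, the child is rewired to the new node, and, if radii exist, the new node gets the mean of the two radii with missing values skipped (as np.nanmean does in the source). The sketch also reproduces the eps="auto" rule from the parameter list. It leaves out crossing detection, smoothing and recentering, and split_edges and auto_eps are hypothetical helpers, not skeletor functions.

```python
import math


def auto_eps(edge_lengths):
    """eps='auto' rule: mean edge length * 1e-4, falling back to 1e-6."""
    finite = [v for v in edge_lengths if math.isfinite(v)]
    mean = sum(finite) / len(finite) if finite else float("nan")
    return mean * 1e-4 if math.isfinite(mean) and mean > 0 else 1e-6


def split_edges(nodes, child_ids):
    """Insert a midpoint node on each child->parent edge listed in `child_ids`.

    `nodes` is a list of dicts with node_id, parent_id, x, y, z and optionally
    radius. Returns a new list; the input rows are not modified.
    """
    nodes = [dict(n) for n in nodes]
    by_id = {n["node_id"]: n for n in nodes}
    next_id = max(by_id) + 1
    for cid in child_ids:
        child = by_id[cid]
        parent = by_id.get(child["parent_id"])
        if parent is None:
            continue  # root node: no edge to split
        new = {"node_id": next_id, "parent_id": parent["node_id"]}
        for axis in ("x", "y", "z"):
            new[axis] = (child[axis] + parent[axis]) / 2.0
        if "radius" in child:
            radii = [r for r in (child.get("radius"), parent.get("radius"))
                     if r is not None and not math.isnan(r)]
            new["radius"] = sum(radii) / len(radii) if radii else float("nan")
        child["parent_id"] = next_id  # rewire child -> new midpoint node
        nodes.append(new)
        by_id[next_id] = new
        next_id += 1
    return nodes


table = [
    {"node_id": 0, "parent_id": -1, "x": 0.0, "y": 0.0, "z": 0.0, "radius": 1.0},
    {"node_id": 1, "parent_id": 0, "x": 4.0, "y": 0.0, "z": 0.0, "radius": 3.0},
]
for row in split_edges(table, [1]):
    print(row)  # node 1 now hangs off new node 2 at (2, 0, 0) with radius 2.0
print(auto_eps([1.0, 2.0, 3.0]))
```

Because the new node is placed at the plain midpoint without projection, a split edge can still cross the surface; in the source that is handled by the follow-up recenter_vertices call and the next split iteration.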