skeletor.post

The skeletor.post module contains functions to post-process skeletons after skeletonization.

Fixing issues with skeletons

Depending on your mesh, pre-processing and the parameters you chose for skeletonization, chances are that your skeleton will not come out perfectly.

clean_up can help you solve some potential issues:

  • skeleton nodes (vertices) that are outside or right on the surface instead of centered inside the mesh
  • superfluous "hairs" on otherwise straight bits

smooth will smooth out the skeleton using rolling windows.

despike can help you remove spikes in the skeleton where single nodes are out of alignment.

remove_bristles will remove single-node "bristles" (terminal twigs) from the skeleton.

Computing radius information

Only skeletor.skeletonize.by_wavefront() provides radii off the bat. For all other methods, you might want to run radii to (re-)generate radius information for the skeletons.

#    This script is part of skeletor (http://www.github.com/navis-org/skeletor).
#    Copyright (C) 2018 Philipp Schlegel
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.

r"""
The `skeletor.post` module contains functions to post-process skeletons after
skeletonization.

### Fixing issues with skeletons

Depending on your mesh, pre-processing and the parameters you chose for
skeletonization, chances are that your skeleton will not come out perfectly.

`skeletor.post.clean_up` can help you solve some potential issues:

- skeleton nodes (vertices) that are outside or right on the surface instead
  of centered inside the mesh
- superfluous "hairs" on otherwise straight bits

`skeletor.post.smooth` will smooth out the skeleton.

`skeletor.post.despike` can help you remove spikes in the skeleton where
single nodes are out of alignment.

`skeletor.post.remove_bristles` will remove bristles from the skeleton.

### Computing radius information

Only `skeletor.skeletonize.by_wavefront()` provides radii off the bat. For all
other methods, you might want to run `skeletor.post.radii` to (re-)generate
radius information for the skeletons.

"""

from .radiusextraction import radii
from .postprocessing import clean_up, smooth, despike, remove_bristles

__docformat__ = "numpy"
__all__ = ["radii", "clean_up", "smooth", "despike", "remove_bristles"]
def radii(s, mesh=None, method='knn', aggregate='mean', validate=False, **kwargs):
def radii(s, mesh=None, method='knn', aggregate='mean', validate=False, **kwargs):
    """Extract radii for given skeleton table.

    Important
    ---------
    This function really only produces useful radii if the skeleton is centered
    inside the mesh. `by_wavefront` does that by default whereas all other
    skeletonization methods don't. Your best bet to get centered skeletons is
    to contract the mesh first (`sk.pre.contract`).

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton for which to extract radii.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    method :    "knn" | "ray"
                Whether and how to add radius information to each node::

                    - "knn" uses k-nearest-neighbors to get radii: fast but
                      potentially very wrong
                    - "ray" uses ray-casting to get radii: slower but sometimes
                      less wrong

    aggregate : "mean" | "median" | "max" | "min" | "percentile75"
                Function used to aggregate radii over sample (i.e. across
                k nearest-neighbors or ray intersections)
    validate :  bool
                If True, will try to fix potential issues with the mesh
                (e.g. infinite values, duplicate vertices, degenerate faces)
                before extracting radii. Note that this might make changes to
                your mesh in place!
    **kwargs
                Keyword arguments are passed to the respective method:

                For method "knn"::

                    n :             int (default 5)
                                    Radius is aggregated (see ``aggregate``)
                                    over the n nearest neighbors.

                For method "ray"::

                    n_rays :        int (default 20)
                                    Number of rays to cast for each node.
                    projection :    "sphere" (default) | "tangents"
                                    Whether to cast rays in a sphere around each node or in a
                                    circle orthogonal to the node's tangent vector.
                    fallback :      "knn" (default) | None | number
                                    If a point is outside or right on the surface of the mesh,
                                    ray casting will return nonsense results. We can either
                                    ignore those cases (``None``), assign an arbitrary number,
                                    or fall back to radii from k-nearest-neighbors (``knn``).

    Returns
    -------
    None
                    But attaches `radius` to the skeleton's SWC table. Existing
                    values are replaced!

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    # Convert to trimesh and (optionally) fix potential mesh issues
    mesh = make_trimesh(mesh, validate=validate)

    if method == 'knn':
        radius = get_radius_knn(s.swc[['x', 'y', 'z']].values,
                                aggregate=aggregate,
                                mesh=mesh, **kwargs)
    elif method == 'ray':
        radius = get_radius_ray(s.swc,
                                mesh=mesh,
                                aggregate=aggregate,
                                **kwargs)
    else:
        raise ValueError(f'Unknown method "{method}"')

    s.swc['radius'] = radius

    return

Extract radii for given skeleton table.

Important

This function really only produces useful radii if the skeleton is centered inside the mesh. by_wavefront does that by default whereas all other skeletonization methods don't. Your best bet to get centered skeletons is to contract the mesh first (sk.pre.contract).

Parameters
  • s (skeletor.Skeleton): Skeleton for which to extract radii.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • method ("knn" | "ray"): Whether and how to add radius information to each node:

    - "knn" uses k-nearest-neighbors to get radii: fast but
      potentially very wrong
    - "ray" uses ray-casting to get radii: slower but sometimes
      less wrong

  • aggregate ("mean" | "median" | "max" | "min" | "percentile75"): Function used to aggregate radii over sample (i.e. across k nearest-neighbors or ray intersections)
  • validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before extracting radii. Note that this might make changes to your mesh in place!
  • **kwargs: Keyword arguments are passed to the respective method:

For method "knn":

n :             int (default 5)
                Radius is aggregated (see aggregate) over the n nearest neighbors.

For method "ray":

n_rays :        int (default 20)
                Number of rays to cast for each node.
projection :    "sphere" (default) | "tangents"
                Whether to cast rays in a sphere around each node or in a
                circle orthogonal to the node's tangent vector.
fallback :      "knn" (default) | None | number
                If a point is outside or right on the surface of the mesh,
                ray casting will return nonsense results. We can either
                ignore those cases (None), assign an arbitrary number, or
                fall back to radii from k-nearest-neighbors ("knn").
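To make the projection="tangents" option more concrete, here is a minimal NumPy sketch of one way to generate n_rays unit directions on the circle orthogonal to a node's tangent vector. The helper name tangent_circle_directions is hypothetical and this is not skeletor's implementation; how the tangent is estimated (here it is simply passed in) is also an assumption.

```python
import numpy as np


def tangent_circle_directions(tangent, n_rays=20):
    """Hypothetical helper (not part of skeletor): return `n_rays` unit
    vectors evenly spaced on the circle orthogonal to `tangent`."""
    t = np.asarray(tangent, dtype=float)
    t = t / np.linalg.norm(t)

    # Pick a helper axis that is not (nearly) parallel to the tangent
    helper = np.array([1.0, 0.0, 0.0])
    if abs(t @ helper) > 0.9:
        helper = np.array([0.0, 1.0, 0.0])

    # Orthonormal basis (u, v) of the plane orthogonal to t
    u = np.cross(t, helper)
    u /= np.linalg.norm(u)
    v = np.cross(t, u)  # unit length because t and u are orthonormal

    # Evenly spaced angles around the circle
    angles = np.linspace(0, 2 * np.pi, n_rays, endpoint=False)
    return np.cos(angles)[:, None] * u + np.sin(angles)[:, None] * v


dirs = tangent_circle_directions([0, 0, 1], n_rays=8)
```

In a ray-casting approach, each of these directions would be cast from the node and the hit distances combined with the chosen aggregate function; the "sphere" projection would instead spread directions over the whole sphere.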
Returns
  • None: Instead, attaches radius to the skeleton's SWC table. Existing values are replaced!
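The "knn" method can be illustrated with a brute-force NumPy sketch. It assumes that a node's radius is derived from its distances to the n nearest mesh vertices, combined with one of the aggregate functions listed above; the helper knn_radii is hypothetical and not skeletor's get_radius_knn.

```python
import numpy as np

# Aggregate functions named like the `aggregate` options above
AGGREGATORS = {
    "mean": np.mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
    "percentile75": lambda d, axis: np.percentile(d, 75, axis=axis),
}


def knn_radii(nodes, vertices, n=5, aggregate="mean"):
    """Hypothetical helper: radius per node from its n nearest vertices."""
    nodes = np.asarray(nodes, dtype=float)
    vertices = np.asarray(vertices, dtype=float)
    # Pairwise node-to-vertex distances, shape (n_nodes, n_vertices)
    dists = np.linalg.norm(nodes[:, None, :] - vertices[None, :, :], axis=2)
    # Distances to the n nearest vertices of each node, shape (n_nodes, n)
    nearest = np.sort(dists, axis=1)[:, :n]
    return AGGREGATORS[aggregate](nearest, axis=1)


# A ring of 16 "mesh vertices" with radius 2 around the z-axis
angles = np.linspace(0, 2 * np.pi, 16, endpoint=False)
ring = np.column_stack([2 * np.cos(angles), 2 * np.sin(angles), np.zeros(16)])
radius = knn_radii([[0, 0, 0]], ring, n=5)  # node at the center: about 2.0
```

The full distance matrix is only reasonable for small inputs; for real meshes a spatial index such as scipy.spatial.cKDTree would normally be used for the neighbor query. The example also shows why knn can be "very wrong": a node that is off-center (or outside the mesh) gets the distance to the nearest surface, not the true local thickness.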
def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
    """Clean up the skeleton.

    This function bundles a bunch of procedures to clean up the skeleton:

      1. Remove twigs that are running parallel to their parent branch
      2. Move nodes outside the mesh back inside (or at least snap to surface)

    Note that this is not a magic bullet and some of this will not work (well)
    if the original mesh was degenerate (e.g. internal faces or not watertight)
    to begin with.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    validate :  bool
                If True, will try to fix potential issues with the mesh
                (e.g. infinite values, duplicate vertices, degenerate faces)
                before cleaning up. Note that this might change your mesh
                in place!
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify `s` in place.

    **kwargs
                Keyword arguments are passed to the bundled function.

                For `skeletor.post.postprocessing.drop_parallel_twigs`::

                 theta :     float (default 0.01)
                             For each twig we generate the dot product between the tangent
                             vectors of the twig and its parent. If these line up perfectly the
                             dot product will equal 1. ``theta`` determines how much that
                             value can differ from 1 for us to still prune the twig: higher
                             theta = more pruning.

    Returns
    -------
    s_clean :   skeletor.Skeleton
                Hopefully improved skeleton.

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    mesh = make_trimesh(mesh, validate=validate)

    if not inplace:
        s = s.copy()

    # Drop parallel twigs
    _ = drop_parallel_twigs(s, theta=kwargs.get('theta', 0.01), inplace=True)

    # Recenter vertices
    _ = recenter_vertices(s, mesh, inplace=True)

    return s

Clean up the skeleton.

This function bundles a bunch of procedures to clean up the skeleton:

  1. Remove twigs that are running parallel to their parent branch
  2. Move nodes outside the mesh back inside (or at least snap to surface)

Note that this is not a magic bullet and some of this will not work (well) if the original mesh was degenerate (e.g. internal faces or not watertight) to begin with.

Parameters
  • s (skeletor.Skeleton): Skeleton to clean up.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • validate (bool): If True, will try to fix potential issues with the mesh (e.g. infinite values, duplicate vertices, degenerate faces) before cleaning up. Note that this might change your mesh in place!
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify s in place.
  • **kwargs: Keyword arguments are passed to the bundled function.

For skeletor.post.postprocessing.drop_parallel_twigs:

theta : float (default 0.01) For each twig we generate the dot product between the tangent vectors of the twig and its parent. If these line up perfectly the dot product will equal 1. theta determines how much that value can differ from 1 for us to still prune the twig: higher theta = more pruning.
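The theta criterion can be sketched as follows. This assumes the twig's and its parent's tangent vectors are already known (how they are estimated is not described here) and that only a dot product near +1 counts as parallel; the helper is_parallel_twig is hypothetical and not part of skeletor.

```python
import numpy as np


def is_parallel_twig(twig_tangent, parent_tangent, theta=0.01):
    """Hypothetical helper: True if the twig runs (almost) parallel to its parent."""
    a = np.asarray(twig_tangent, dtype=float)
    b = np.asarray(parent_tangent, dtype=float)
    # Normalize so that perfectly aligned vectors give a dot product of 1
    dot = (a / np.linalg.norm(a)) @ (b / np.linalg.norm(b))
    # Prune if the dot product is within `theta` of 1
    return bool(dot >= 1 - theta)


nearly_aligned = is_parallel_twig([1, 0, 0], [1, 0.05, 0])  # dot is about 0.9988
```

Raising theta widens the accepted band below 1, which is why a higher theta prunes more twigs.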

Returns
  • s_clean (skeletor.Skeleton): Hopefully improved skeleton.
def smooth(s, window: int = 3, to_smooth: list = ['x', 'y', 'z'], inplace: bool = False):
def smooth(s,
           window: int = 3,
           to_smooth: list = ['x', 'y', 'z'],
           inplace: bool = False):
    """Smooth skeleton using rolling windows.

    Parameters
    ----------
    s :             skeletor.Skeleton
                    Skeleton to be processed.
    window :        int, optional
                    Size (N observations) of the rolling window in number of
                    nodes.
    to_smooth :     list
                    Columns of the node table to smooth. Should work with any
                    numeric column (e.g. 'radius').
    inplace :       bool
                    If False will make and return a copy of the skeleton. If
                    True, will modify `s` in place.

    Returns
    -------
    s :             skeletor.Skeleton
                    Skeleton with smoothed node table.

    """
    if not inplace:
        s = s.copy()

    # Prepare nodes: copy the node table and index it by node ID
    nodes = s.swc.set_index('node_id', inplace=False).copy()

    to_smooth = np.array(to_smooth)
    miss = to_smooth[~np.isin(to_smooth, nodes.columns)]
    if len(miss):
        raise ValueError(f'Column(s) not found in node table: {miss}')

    # Go over each segment and smooth
    for seg in s.get_segments():
        # Get this segment's values for the columns to smooth
        this_co = nodes.loc[seg, to_smooth]

        # Trailing rolling mean; min_periods=1 keeps the first nodes of the segment
        interp = this_co.rolling(window, min_periods=1).mean()

        nodes.loc[seg, to_smooth] = interp.values

    # Reassign nodes
    s.swc = nodes.reset_index(drop=False, inplace=False)

    return s

Smooth skeleton using rolling windows.

Parameters
  • s (skeletor.Skeleton): Skeleton to be processed.
  • window (int, optional): Size (N observations) of the rolling window in number of nodes.
  • to_smooth (list): Columns of the node table to smooth. Should work with any numeric column (e.g. 'radius').
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify s in place.
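The listing above smooths each segment with pandas' rolling(window, min_periods=1).mean(), which is a trailing window. As a dependency-light illustration, here is a NumPy sketch of the same trailing mean applied to the coordinates of a single segment; the helper rolling_mean is hypothetical and not part of skeletor.

```python
import numpy as np


def rolling_mean(values, window=3):
    """Hypothetical helper: trailing rolling mean with min_periods=1, i.e.
    the same result as pandas' rolling(window, min_periods=1).mean()."""
    values = np.asarray(values, dtype=float)
    out = np.empty_like(values)
    for i in range(len(values)):
        # Average over the current row and up to `window - 1` preceding rows
        start = max(0, i - window + 1)
        out[i] = values[start:i + 1].mean(axis=0)
    return out


# A zig-zag segment: the y coordinate alternates between 0 and 1
segment = [[0, 0, 0], [1, 1, 0], [2, 0, 0], [3, 1, 0]]
smoothed = rolling_mean(segment, window=2)  # y becomes 0, 0.5, 0.5, 0.5
```

Because the window trails, each node is pulled toward the nodes before it in the segment order; pandas' rolling also accepts center=True if a symmetric window is preferred.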
Returns
  • s (skeletor.Skeleton): Skeleton with smoothed node table.
def despike(s, sigma=5, max_spike_length=1, inplace=False, reverse=False):
def despike(s,
            sigma=5,
            max_spike_length=1,
            inplace=False,
            reverse=False):
    r"""Remove spikes in skeleton.

    For each node A, the Euclidean distance to its next successor (parent)
    B and that node's successor C (i.e. A->B->C) is computed. If
    :math:`\frac{dist(A,B)}{dist(A,C)}>sigma`, node B is considered a spike
    and realigned between A and C.

    Parameters
    ----------
    s :                 skeletor.Skeleton
                        Skeleton to be processed.
    sigma :             float | int, optional
                        Threshold for spike detection. Smaller sigma = more
                        aggressive spike detection.
    max_spike_length :  int, optional
                        Determines how long (# of nodes) a spike can be.
    inplace :           bool, optional
                        If False, a copy of the skeleton is returned.
    reverse :           bool, optional
                        If True, will **also** walk the segments from proximal
                        to distal. Use this to catch spikes on e.g. terminal
                        nodes.

    Returns
    -------
    s :                 skeletor.Skeleton
                        Despiked skeleton.

    """
    if not inplace:
        s = s.copy()

    # Index nodes table by node ID
    this_nodes = s.swc.set_index('node_id', inplace=False)

    segments = s.get_segments()
    # Copy the list so that extending it does not modify `segments`
    segs_to_walk = list(segments)

    if reverse:
        # Also walk each segment in the opposite direction
        segs_to_walk += [seg[::-1] for seg in segments]

    # For each spike length -> do this in reverse to correct the long
    # spikes first
    for l in list(range(1, max_spike_length + 1))[::-1]:
        # Go over all segments
        for seg in segs_to_walk:
            # Get nodes A, B and C of this segment
            this_A = this_nodes.loc[seg[:-l - 1]]
            this_B = this_nodes.loc[seg[l:-1]]
            this_C = this_nodes.loc[seg[l + 1:]]

            # Get coordinates
            A = this_A[['x', 'y', 'z']].values
            B = this_B[['x', 'y', 'z']].values
            C = this_C[['x', 'y', 'z']].values

            # Calculate Euclidean distances A->B and A->C
            dist_AB = np.linalg.norm(A - B, axis=1)
            dist_AC = np.linalg.norm(A - C, axis=1)

            # Get the spikes (ratio is 0 where A and C coincide)
            ratio = np.divide(dist_AB, dist_AC,
                              out=np.zeros_like(dist_AB),
                              where=dist_AC != 0)
            spikes_ix = np.where(ratio > sigma)[0]
            spikes = this_B.iloc[spikes_ix]

            if not spikes.empty:
                # Interpolate new position(s) between A and C
                new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2

                this_nodes.loc[spikes.index, ['x', 'y', 'z']] = new_positions

    # Reassign node table
    s.swc = this_nodes.reset_index(drop=False, inplace=False)

    return s

Remove spikes in skeleton.

For each node A, the Euclidean distance to its next successor (parent) B and that node's successor C (i.e. A->B->C) is computed. If \( \frac{dist(A,B)}{dist(A,C)}>sigma \), node B is considered a spike and realigned between A and C.
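The ratio test can be sketched on a single segment with NumPy. This covers only spikes of length 1 on one array of coordinates ordered A->B->C; the real function also loops over all segments, spike lengths up to max_spike_length and, optionally, the reversed walk. The helper despike_segment is hypothetical.

```python
import numpy as np


def despike_segment(coords, sigma=5.0):
    """Hypothetical helper: realign single-node spikes in one segment."""
    coords = np.array(coords, dtype=float)  # copy so the input is not modified
    A, B, C = coords[:-2], coords[1:-1], coords[2:]

    dist_AB = np.linalg.norm(A - B, axis=1)
    dist_AC = np.linalg.norm(A - C, axis=1)

    # dist(A,B) / dist(A,C); set to 0 where A and C coincide
    ratio = np.divide(dist_AB, dist_AC, out=np.zeros_like(dist_AB),
                      where=dist_AC != 0)
    spikes = np.where(ratio > sigma)[0]

    # Move each spike node B (at index i + 1) to the midpoint between A and C
    coords[spikes + 1] = (A[spikes] + C[spikes]) / 2
    return coords


# Node 1 sticks out 20 units from an otherwise straight line
segment = [[0, 0, 0], [1, 20, 0], [2, 0, 0], [3, 0, 0]]
fixed = despike_segment(segment, sigma=5)
```

Here dist(A,B) is about 20.02 and dist(A,C) is 2, so the ratio (about 10) exceeds sigma=5 and node 1 is moved back onto the line.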

Parameters
  • s (skeletor.Skeleton): Skeleton to be processed.
  • sigma (float | int, optional): Threshold for spike detection. Smaller sigma = more aggressive spike detection.
  • max_spike_length (int, optional): Determines how long (# of nodes) a spike can be.
  • inplace (bool, optional): If False, a copy of the skeleton is returned.
  • reverse (bool, optional): If True, will also walk the segments from proximal to distal. Use this to catch spikes on e.g. terminal nodes.
Returns
  • s (skeletor.Skeleton): Despiked skeleton.
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
    """Remove "bristles" that sometimes occur along the backbone.

    Works by finding terminal twigs that consist of only a single node.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to clean up.
    mesh :      trimesh.Trimesh, optional
                Original mesh (e.g. before contraction). If not provided will
                use the mesh associated with ``s``.
    los_only :  bool
                If True, will only remove bristles that are in line of sight of
                their parent. If False, will remove all single-node bristles.
    inplace :   bool
                If False will make and return a copy of the skeleton. If True,
                will modify `s` in place.

    Returns
    -------
    s :         skeletor.Skeleton
                Skeleton with single-node twigs removed.

    """
    if isinstance(mesh, type(None)):
        mesh = s.mesh

    # Make a copy of the skeleton
    if not inplace:
        s = s.copy()

    # Find branch points
    pcount = s.swc[s.swc.parent_id >= 0].groupby('parent_id').size()
    bp = pcount[pcount > 1].index

    # Find single-node terminal twigs: leaf nodes whose parent is a branch point
    twigs = s.swc[~s.swc.node_id.isin(s.swc.parent_id)]
    twigs = twigs[twigs.parent_id.isin(bp)]

    if twigs.empty:
        return s

    if los_only:
        # Initialize ncollpyde Volume
        coll = ncollpyde.Volume(mesh.vertices, mesh.faces, validate=False)

        # Remove twigs that aren't inside the volume
        twigs = twigs[coll.contains(twigs[['x', 'y', 'z']].values)]

        # Generate rays between all twigs and their parents
        sources = twigs[['x', 'y', 'z']].values
        targets = s.swc.set_index('node_id').loc[twigs.parent_id,
                                                ['x', 'y', 'z']].values

        # Get intersections: `ix` points to index of line segment; `loc` is the
        # x/y/z coordinate of the intersection and `is_backface` is True if
        # intersection happened at the inside of a mesh
        ix, loc, is_backface = coll.intersections(sources, targets)

        # Find pairs of twigs with no intersection - i.e. with line of sight
        los = ~np.isin(np.arange(sources.shape[0]), ix)

        # To remove: have line of sight
        to_remove = twigs[los]
    else:
        to_remove = twigs

    s.swc = s.swc[~s.swc.node_id.isin(to_remove.node_id)].copy()

    # Update the mesh map
    mesh_map = getattr(s, 'mesh_map', None)
    if not isinstance(mesh_map, type(None)):
        for t in to_remove.itertuples():
            mesh_map[mesh_map == t.node_id] = t.parent_id

    # Reindex nodes
    s.reindex(inplace=True)

    return s

Remove "bristles" that sometimes occur along the backbone.

Works by finding terminal twigs that consist of only a single node.
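The bristle detection (ignoring the optional line-of-sight check) can be mirrored with plain NumPy arrays of node and parent IDs, following the same logic as the pandas code in the listing above: branch points are nodes with more than one child, leaves are nodes that are nobody's parent, and bristles are leaves whose parent is a branch point. The helper find_bristles is hypothetical.

```python
import numpy as np


def find_bristles(node_ids, parent_ids):
    """Hypothetical helper: IDs of leaf nodes whose parent is a branch point."""
    node_ids = np.asarray(node_ids)
    parent_ids = np.asarray(parent_ids)

    # Branch points: nodes that are the parent of more than one node
    has_parent = parent_ids >= 0
    parents, counts = np.unique(parent_ids[has_parent], return_counts=True)
    branch_points = parents[counts > 1]

    # Leaves: nodes that are nobody's parent
    is_leaf = ~np.isin(node_ids, parent_ids)

    # Bristles: leaves hanging directly off a branch point
    return node_ids[is_leaf & np.isin(parent_ids, branch_points)]


# Tree 0 <- 1 <- 2 <- 3, plus node 4 attached directly to branch point 1
bristles = find_bristles([0, 1, 2, 3, 4], [-1, 0, 1, 2, 1])
```

Node 3 is also a leaf, but its parent (node 2) is not a branch point, so only node 4 is reported.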

Parameters
  • s (skeletor.Skeleton): Skeleton to clean up.
  • mesh (trimesh.Trimesh, optional): Original mesh (e.g. before contraction). If not provided will use the mesh associated with s.
  • los_only (bool): If True, will only remove bristles that are in line of sight of their parent. If False, will remove all single-node bristles.
  • inplace (bool): If False will make and return a copy of the skeleton. If True, will modify s in place.
Returns
  • s (skeletor.Skeleton): Skeleton with single-node twigs removed.