Skip to content

smat

Class converting continuous values into discrete indices.

PARAMETER DESCRIPTION
boundaries

N boundaries specifying N-1 bins. Must be monotonically increasing.

TYPE: Sequence[float]

clip

Whether to set the bottom and top boundaries to -infinity and infinity respectively, effectively clipping incoming values: by default (True, True). False means "add a new bin for out-of-range values".

TYPE: Tuple[bool, bool] DEFAULT: (True, True)

right

Whether bins should include their right (rather than left) boundary, by default False.

TYPE: bool DEFAULT: False

Source code in navis/nbl/smat.py
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
class Digitizer(LookupAxis[float]):
    """Class converting continuous values into discrete indices.

    Parameters
    ----------
    boundaries : Sequence[float]
        N boundaries specifying N-1 bins.
        Must be monotonically increasing.
    clip : Tuple[bool, bool], optional
        Whether to set the bottom and top boundaries to -infinity and
        infinity respectively, effectively clipping incoming values: by
        default (True, True).
        False means "add a new bin for out-of-range values".
    right : bool, optional
        Whether bins should include their right (rather than left) boundary,
        by default False.
    """

    def __init__(
        self,
        boundaries: Sequence[float],
        clip: Tuple[bool, bool] = (True, True),
        right=False,
    ):
        self.right = right

        boundaries = list(boundaries)
        # Remember the original outermost bounds so that `to_strings` can
        # still report them after they are replaced by +/- infinity below.
        self._min = -math.inf
        if clip[0]:
            # Clip from below: values smaller than the first boundary fall
            # into the first bin.
            self._min = boundaries[0]
            boundaries[0] = -math.inf
        elif boundaries[0] != -math.inf:
            # Add an extra catch-all bin below the given range.
            self._min = -math.inf
            boundaries.insert(0, -math.inf)

        self._max = math.inf
        if clip[1]:
            # Clip from above: values larger than the last boundary fall
            # into the last bin.
            self._max = boundaries[-1]
            boundaries[-1] = math.inf
        elif boundaries[-1] != math.inf:
            # Add an extra catch-all bin above the given range.
            boundaries.append(math.inf)

        if not is_monotonically_increasing(boundaries):
            raise ValueError(
                "Boundaries are not monotonically increasing: " f"{boundaries}"
            )

        self.boundaries = np.asarray(boundaries)

    def __len__(self):
        """Number of bins (one fewer than the number of boundaries)."""
        return len(self.boundaries) - 1

    def __call__(self, value: float):
        # searchsorted is marginally faster than digitize as it skips monotonicity checks
        return (
            np.searchsorted(
                self.boundaries, value, side="left" if self.right else "right"
            )
            - 1
        )

    def to_strings(self, round=None) -> List[str]:
        """Turn boundaries into list of labels.

        Parameters
        ----------
        round :     int, optional
                    Use to round bounds to the Nth decimal.
        """
        if self.right:
            lb = "("
            rb = "]"
        else:
            lb = "["
            rb = ")"

        # Restore the original outermost bounds (replaced by +/- infinity
        # in `__init__`) for display purposes.
        b = self.boundaries.copy()
        b[0] = self._min
        b[-1] = self._max

        # `is not None` so that `round=0` (round to whole numbers) works too
        if round is not None:
            b = [np.round(x, round) for x in b]

        return [f"{lb}{lower},{upper}{rb}" for lower, upper in zip(b[:-1], b[1:])]

    @classmethod
    def from_strings(cls, interval_strs: Sequence[str]):
        """Set digitizer boundaries based on a sequence of interval expressions.

        e.g. `["(0, 1]", "(1, 5]", "(5, 10]"]`

        The lowermost and uppermost boundaries are converted to -infinity and
        infinity respectively.

        Parameters
        ----------
        interval_strs : Sequence[str]
            Strings representing intervals, which must abut and have open/closed
            boundaries specified by brackets.

        Returns
        -------
        Digitizer
        """
        bounds: List[float] = []
        last_upper = None
        last_right = None
        for item in interval_strs:
            (lower, upper), right = parse_boundary(item)
            bounds.append(float(lower))

            # All intervals must be open/closed on the same side
            if last_right is not None:
                if right != last_right:
                    raise ValueError("Inconsistent half-open interval")
            else:
                last_right = right

            # Each interval must start exactly where the previous one ended
            if last_upper is not None:
                if lower != last_upper:
                    raise ValueError("Half-open intervals do not abut")

            last_upper = upper

        # Guard against an empty sequence, which would otherwise fail below
        # with an unhelpful TypeError from `float(None)`.
        if last_upper is None:
            raise ValueError("No intervals given")

        bounds.append(float(last_upper))
        return cls(bounds, right=last_right)

    @classmethod
    def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
        """Choose digitizer boundaries spaced linearly between two values.

        Input values will be clipped to fit within the given interval.

        Parameters
        ----------
        lower : float
            Lowest value
        upper : float
            Highest value
        nbins : int
            Number of bins
        right : bool, optional
            Whether bins should include their right (rather than left) boundary,
            by default False

        Returns
        -------
        Digitizer
        """
        # nbins bins require nbins + 1 boundaries
        arr = np.linspace(lower, upper, nbins + 1, endpoint=True)
        return cls(arr, right=right)

    @classmethod
    def from_geom(
        cls, lowest_upper: float, highest_lower: float, nbins: int, right=False
    ):
        """Choose digitizer boundaries in a geometric sequence.

        Additional bins will be added above and below the given values.

        Parameters
        ----------
        lowest_upper : float
            Upper bound of the lowest bin. The lower bound of the lowest bin is
            often 0, which cannot be represented in a nontrivial geometric
            sequence.
        highest_lower : float
            Lower bound of the highest bin.
        nbins : int
            Number of bins
        right : bool, optional
            Whether bins should include their right (rather than left) boundary,
            by default False

        Returns
        -------
        Digitizer
        """
        # Only nbins - 1 boundaries are generated here; clip=(False, False)
        # adds the outer catch-all bins, yielding nbins bins in total.
        arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True)
        return cls(arr, clip=(False, False), right=right)

    @classmethod
    def from_data(
        cls, data: Sequence[float], nbins: int, right=False, method="quantile"
    ):
        """Choose digitizer boundaries to evenly partition the given values.

        Parameters
        ----------
        data : Sequence[float]
            Data which should be partitioned by the resulting digitizer.
        nbins : int
            Number of bins
        right : bool, optional
            Whether bins should include their right (rather than left) boundary,
            by default False
        method : "quantile" | "linear" | "geometric"
            Method to use for partitioning the data space:
             - 'quantile' (default) will partition the data such that each bin
               contains the same number of data points. This is usually the
               method of choice because it is robust against outliers and
               because we are guaranteed to not have empty bins.
             - 'linear' will partition the data into evenly spaced bins.
             - 'geometric' will produce a log scale partition. This will not work
               if data has negative values.

        Returns
        -------
        Digitizer
        """
        # Explicit check instead of `assert` so validation survives `python -O`
        if method not in ("quantile", "linear", "geometric"):
            raise ValueError(f"Unknown binning method: {method!r}")

        if method == "quantile":
            arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
        elif method == "linear":
            arr = np.linspace(min(data), max(data), nbins + 1, True)
        elif method == "geometric":
            if min(data) <= 0:
                raise ValueError(
                    "Data must not have values <= 0 for creating "
                    "geometric (logarithmic) bins."
                )
            arr = np.geomspace(min(data), max(data), nbins + 1, True)
        return cls(arr, right=right)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Digitizer):
            return NotImplemented
        return self.right == other.right and np.allclose(
            self.boundaries, other.boundaries
        )

Choose digitizer boundaries to evenly partition the given values.

PARAMETER DESCRIPTION
data

Data which should be partitioned by the resulting digitizer.

TYPE: Sequence[float]

nbins

Number of bins

TYPE: int

right

Whether bins should include their right (rather than left) boundary, by default False

TYPE: bool DEFAULT: False

method

Method to use for partitioning the data space: - 'quantile' (default) will partition the data such that each bin contains the same number of data points. This is usually the method of choice because it is robust against outliers and because we are guaranteed not to have empty bins. - 'linear' will partition the data into evenly spaced bins. - 'geometric' will produce a log scale partition. This will not work if data has negative values.

TYPE: 'quantile' | 'linear' | 'geometric' DEFAULT: 'quantile'

RETURNS DESCRIPTION
Digitizer
Source code in navis/nbl/smat.py
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
@classmethod
def from_data(
    cls, data: Sequence[float], nbins: int, right=False, method="quantile"
):
    """Choose digitizer boundaries to evenly partition the given values.

    Parameters
    ----------
    data : Sequence[float]
        Data which should be partitioned by the resulting digitizer.
    nbins : int
        Number of bins
    right : bool, optional
        Whether bins should include their right (rather than left) boundary,
        by default False
    method : "quantile" | "linear" | "geometric"
        Method to use for partitioning the data space:
         - 'quantile' (default) will partition the data such that each bin
           contains the same number of data points. This is usually the
           method of choice because it is robust against outliers and
           because we are guaranteed to not have empty bins.
         - 'linear' will partition the data into evenly spaced bins.
         - 'geometric' will produce a log scale partition. This will not work
           if data has negative values.

    Returns
    -------
    Digitizer

    Raises
    ------
    ValueError
        If `method` is unknown, or `method="geometric"` is used with
        non-positive data.
    """
    # Explicit check instead of `assert` so validation survives `python -O`
    if method not in ("quantile", "linear", "geometric"):
        raise ValueError(f"Unknown binning method: {method!r}")

    if method == "quantile":
        # Equal-count bins: boundaries at evenly spaced quantiles
        arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
    elif method == "linear":
        arr = np.linspace(min(data), max(data), nbins + 1, True)
    elif method == "geometric":
        if min(data) <= 0:
            raise ValueError(
                "Data must not have values <= 0 for creating "
                "geometric (logarithmic) bins."
            )
        arr = np.geomspace(min(data), max(data), nbins + 1, True)
    return cls(arr, right=right)

Choose digitizer boundaries in a geometric sequence.

Additional bins will be added above and below the given values.

PARAMETER DESCRIPTION
lowest_upper

Upper bound of the lowest bin. The lower bound of the lowest bin is often 0, which cannot be represented in a nontrivial geometric sequence.

TYPE: float

highest_lower

Lower bound of the highest bin.

TYPE: float

nbins

Number of bins

TYPE: int

right

Whether bins should include their right (rather than left) boundary, by default False

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Digitizer
Source code in navis/nbl/smat.py
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
@classmethod
def from_geom(
    cls, lowest_upper: float, highest_lower: float, nbins: int, right=False
):
    """Build a digitizer whose inner boundaries form a geometric progression.

    Because a nontrivial geometric sequence cannot contain 0, only the
    `nbins - 1` inner boundaries are generated here; passing
    `clip=(False, False)` then adds a catch-all bin below and above the
    given range, for `nbins` bins in total.

    Parameters
    ----------
    lowest_upper : float
        Upper bound of the lowest bin. The lower bound of the lowest bin is
        often 0, which cannot be represented in a nontrivial geometric
        sequence.
    highest_lower : float
        Lower bound of the highest bin.
    nbins : int
        Number of bins
    right : bool, optional
        Whether bins should include their right (rather than left) boundary,
        by default False

    Returns
    -------
    Digitizer
    """
    inner_bounds = np.geomspace(
        lowest_upper, highest_lower, num=nbins - 1, endpoint=True
    )
    return cls(inner_bounds, clip=(False, False), right=right)

Choose digitizer boundaries spaced linearly between two values.

Input values will be clipped to fit within the given interval.

PARAMETER DESCRIPTION
lower

Lowest value

TYPE: float

upper

Highest value

TYPE: float

nbins

Number of bins

TYPE: int

right

Whether bins should include their right (rather than left) boundary, by default False

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Digitizer
Source code in navis/nbl/smat.py
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
@classmethod
def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
    """Build a digitizer with evenly spaced boundaries.

    Input values will be clipped to fit within the given interval.

    Parameters
    ----------
    lower : float
        Lowest value
    upper : float
        Highest value
    nbins : int
        Number of bins
    right : bool, optional
        Whether bins should include their right (rather than left) boundary,
        by default False

    Returns
    -------
    Digitizer
    """
    # nbins bins require nbins + 1 boundaries
    edges = np.linspace(lower, upper, num=nbins + 1, endpoint=True)
    return cls(edges, right=right)

Set digitizer boundaries based on a sequence of interval expressions.

e.g. ["(0, 1]", "(1, 5]", "(5, 10]"]

The lowermost and uppermost boundaries are converted to -infinity and infinity respectively.

PARAMETER DESCRIPTION
bound_strs

Strings representing intervals, which must abut and have open/closed boundaries specified by brackets.

TYPE: Sequence[str]

RETURNS DESCRIPTION
Digitizer
Source code in navis/nbl/smat.py
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
@classmethod
def from_strings(cls, interval_strs: Sequence[str]):
    """Set digitizer boundaries based on a sequence of interval expressions.

    e.g. `["(0, 1]", "(1, 5]", "(5, 10]"]`

    The lowermost and uppermost boundaries are converted to -infinity and
    infinity respectively.

    Parameters
    ----------
    interval_strs : Sequence[str]
        Strings representing intervals, which must abut and have open/closed
        boundaries specified by brackets.

    Returns
    -------
    Digitizer

    Raises
    ------
    ValueError
        If the sequence is empty, the intervals do not abut, or they are
        inconsistently half-open.
    """
    bounds: List[float] = []
    last_upper = None
    last_right = None
    for item in interval_strs:
        (lower, upper), right = parse_boundary(item)
        bounds.append(float(lower))

        # All intervals must be open/closed on the same side
        if last_right is not None:
            if right != last_right:
                raise ValueError("Inconsistent half-open interval")
        else:
            last_right = right

        # Each interval must start exactly where the previous one ended
        if last_upper is not None:
            if lower != last_upper:
                raise ValueError("Half-open intervals do not abut")

        last_upper = upper

    # Guard against an empty sequence, which would otherwise fail below
    # with an unhelpful TypeError from `float(None)`.
    if last_upper is None:
        raise ValueError("No intervals given")

    bounds.append(float(last_upper))
    return cls(bounds, right=last_right)

Turn boundaries into list of labels.

PARAMETER DESCRIPTION
round
    Use to round bounds to the Nth decimal.

TYPE: int DEFAULT: None

Source code in navis/nbl/smat.py
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
def to_strings(self, round=None) -> List[str]:
    """Turn boundaries into a list of interval labels, e.g. `"[0.0,1.0)"`.

    Bracket style reflects `self.right`: `(lo,hi]` for right-closed bins,
    `[lo,hi)` otherwise.

    Parameters
    ----------
    round :     int, optional
                Use to round bounds to the Nth decimal.
    """
    if self.right:
        lb = "("
        rb = "]"
    else:
        lb = "["
        rb = ")"

    # Restore the original outermost bounds (replaced by +/- infinity
    # in `__init__`) for display purposes.
    b = self.boundaries.copy()
    b[0] = self._min
    b[-1] = self._max

    # `is not None` so that `round=0` (round to whole numbers) works too
    if round is not None:
        b = [np.round(x, round) for x in b]

    return [f"{lb}{lower},{upper}{rb}" for lower, upper in zip(b[:-1], b[1:])]

Class converting some data into a linear index.

Source code in navis/nbl/smat.py
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
class LookupAxis(ABC, Generic[T]):
    """Abstract interface mapping data values onto linear (integer) indices."""

    @abstractmethod
    def __len__(self) -> int:
        """Return how many bins this axis represents."""
        ...

    @abstractmethod
    def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
        """Map a value (or sequence of values) onto bin indices.

        Parameters
        ----------
        value : Union[T, Sequence[T]]
            Value to convert into an index

        Returns
        -------
        Union[int, Sequence[int]]
            If a scalar was given, return a scalar; otherwise, a numpy array of ints.
        """
        ...

Class for building a 2-dimensional score lookup for NBLAST.

The scores are

  1. The distances between best-matching points
  2. The dot products of direction vectors around those points, optionally scaled by the colinearity alpha.
PARAMETER DESCRIPTION
dotprops
            An indexable sequence of all neurons which will be
            used as the training set, as Dotprops objects.

TYPE: dict or list of Dotprops

matching_lists
            List of neurons, as indices into `dotprops`, which
            should be considered matches.

TYPE: list of lists of indices into dotprops

nonmatching_list
            List of neurons, as indices into `dotprops`,
            which should not be considered matches.
            If not given, all `dotprops` will be used
            (on the assumption that matches are a small subset
            of possible pairs).

TYPE: list of indices into dotprops DEFAULT: None

use_alpha
            If true, multiply the dot product by the geometric
            mean of the matched points' alpha values
            (i.e. `sqrt(alpha1 * alpha2)`).

TYPE: bool DEFAULT: False

draw_strat
        Strategy for randomly drawing non-matching pairs.
        "batched" should be the right choice in most scenarios.
        "greedy" can be better if your pool of neurons is very
        small.

TYPE: "batched" | "greedy" DEFAULT: 'batched'

seed
            Non-matching pairs are drawn at random using this
            seed, by default {DEFAULT_SEED}.

TYPE: int DEFAULT: DEFAULT_SEED

Source code in navis/nbl/smat.py
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
class LookupDistDotBuilder(LookupNdBuilder):
    """Class for building a 2-dimensional score lookup for NBLAST.

    The scores are

    1. The distances between best-matching points
    2. The dot products of direction vectors around those points,
        optionally scaled by the colinearity `alpha`.

    Parameters
    ----------
    dotprops :          dict or list of Dotprops
                        An indexable sequence of all neurons which will be
                        used as the training set, as Dotprops objects.
    matching_lists :    list of lists of indices into dotprops
                        List of neurons, as indices into `dotprops`, which
                        should be considered matches.
    nonmatching_list :  list of indices into dotprops, optional
                        List of neurons, as indices into `dotprops`,
                        which should not be considered matches.
                        If not given, all `dotprops` will be used
                        (on the assumption that matches are a small subset
                        of possible pairs).
    use_alpha :         bool, optional
                        If true, multiply the dot product by the geometric
                        mean of the matched points' alpha values
                        (i.e. `sqrt(alpha1 * alpha2)`).
    draw_strat :        "batched" | "greedy"
                        Strategy for randomly drawing non-matching pairs.
                        "batched" should be the right choice in most scenarios.
                        "greedy" can be better if your pool of neurons is very
                        small.
    seed :              int, optional
                        Non-matching pairs are drawn at random using this
                        seed, by default {DEFAULT_SEED}.
    """

    def __init__(
        self,
        dotprops: Union[List["core.Dotprops"], Mapping[NeuronKey, "core.Dotprops"]],
        matching_lists: List[List[NeuronKey]],
        nonmatching_list: Optional[List[NeuronKey]] = None,
        use_alpha: bool = False,
        draw_strat: str = "batched",
        seed: int = DEFAULT_SEED,
    ):
        # Pick the pairwise scoring function; `dist_dot_alpha` additionally
        # scales the dot product by the matched points' alpha values.
        match_fn = dist_dot_alpha if use_alpha else dist_dot
        super().__init__(
            dotprops,
            matching_lists,
            match_fn,
            nonmatching_list,
            draw_strat=draw_strat,
            seed=seed,
        )
        # Both match functions produce two value arrays per pair
        # (distances and dot products), hence two lookup dimensions.
        self._ndim = 2

    def build(self, threads=None) -> Lookup2d:
        # Build the generic N-d lookup (here N=2) and wrap it in the
        # 2-dimensional convenience class.
        (dig0, dig1), cells = self._build(threads)
        return Lookup2d(dig0, dig1, cells)
Source code in navis/nbl/smat.py
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
class LookupNd:
    """N-dimensional lookup table.

    Maps continuous coordinates onto cell values using one `LookupAxis`
    per dimension of `cells`.

    Parameters
    ----------
    axes : List[LookupAxis]
        One axis per dimension; `len(axis)` must equal the size of the
        corresponding dimension of `cells`.
    cells : np.ndarray
        Array holding the values to look up.

    Raises
    ------
    ValueError
        If the axes' bin counts are inconsistent with the shape of `cells`.
    """

    def __init__(self, axes: List["LookupAxis"], cells: np.ndarray):
        if [len(b) for b in axes] != list(cells.shape):
            raise ValueError("axes and cells have inconsistent bin counts")
        self.axes = axes
        self.cells = cells

    def __call__(self, *args):
        """Return the cell value(s) at the given coordinates.

        Takes one argument per axis; each axis converts its argument into
        an index (or array of indices) used to address `cells`.
        """
        if len(args) != len(self.axes):
            raise TypeError(
                f"Lookup takes {len(self.axes)} arguments but {len(args)} were given"
            )

        idxs = tuple(d(arg) for d, arg in zip(self.axes, args))
        out = self.cells[idxs]
        return out

Class for building an N-dimensional score lookup (for e.g. NBLAST).

Once instantiated, the axes of the lookup table must be defined. Call .with_digitizers() to manually define them, or .with_bin_counts() to learn them from the matched-pair data.

Then call .build() to build the lookup table.

PARAMETER DESCRIPTION
neurons
        An indexable, consistently-ordered sequence of all
        objects (typically neurons) which will be used as the
        training set. Importantly: each object must have a
        `len()`!

TYPE: dict or list of objects (e.g. Dotprops)

matching_sets
        Lists of neurons, as indices into `neurons`, which
        should be considered matches:

            [[0, 1, 2, 4], [5, 6], [9, 10, 11]]

TYPE: list of lists of index into `neurons`

match_fn
        Function taking 2 arguments, both instances of type
        `neurons`, and returning a list of 1D
        `numpy.ndarray`s of floats. The length of the list
        must be the same as the length of `boundaries`.
        The length of the `array`s must be the same as the
        number of points in the first argument. This function
        returns values describing the quality of point matches
        from a query to a target neuron.

TYPE: Callable[[object, object], List[np.ndarray[float]]]

nonmatching
        List of objects, as indices into `neurons`, which
        should be considered NON-matches. If not given,
        all `neurons` will be used (on the assumption that
        matches are a small subset of possible pairs).

TYPE: list of index into `neurons`

draw_strat
        Strategy for randomly drawing non-matching pairs. Only
        relevant if `nonmatching` is not provided.
        "batched" should be the right choice in most scenarios.
        "greedy" can be better if your pool of neurons is very
        small.

TYPE: "batched" | "greedy" DEFAULT: 'batched'

seed
        Non-matching pairs are drawn at random using this seed,
        by default {DEFAULT_SEED}.

TYPE: int DEFAULT: DEFAULT_SEED

Source code in navis/nbl/smat.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
class LookupNdBuilder:
    """Class for building an N-dimensional score lookup (for e.g. NBLAST).

    Once instantiated, the axes of the lookup table must be defined.
    Call `.with_digitizers()` to manually define them, or
    `.with_bin_counts()` to learn them from the matched-pair data.

    Then call `.build()` to build the lookup table.

    Parameters
    ----------
    neurons :       dict or list of objects (e.g. Dotprops)
                    An indexable, consistently-ordered sequence of all
                    objects (typically neurons) which will be used as the
                    training set. Importantly: each object must have a
                    `len()`!
    matching_lists : list of lists of index into `neurons`
                    Lists of neurons, as indices into `neurons`, which
                    should be considered matches:

                        [[0, 1, 2, 4], [5, 6], [9, 10, 11]]

    match_fn :      Callable[[object, object], List[np.ndarray[float]]]
                    Function taking 2 arguments, both instances of type
                    `neurons`, and returning a list of 1D
                    `numpy.ndarray`s of floats. The length of the list
                    determines the dimensionality of the lookup table.
                    The length of the `array`s must be the same as the
                    number of points in the first argument. This function
                    returns values describing the quality of point matches
                    from a query to a target neuron.
    nonmatching_list : list of index into `neurons`, optional
                    List of objects, as indices into `neurons`, which
                    should be considered NON-matches. If not given,
                    all `neurons` will be used (on the assumption that
                    matches are a small subset of possible pairs).
    draw_strat :    "batched" | "greedy"
                    Strategy for randomly drawing non-matching pairs. Only
                    relevant if `nonmatching_list` is not provided.
                    "batched" should be the right choice in most scenarios.
                    "greedy" can be better if your pool of neurons is very
                    small.
    seed :          int, optional
                    Non-matching pairs are drawn at random using this seed,
                    by default `DEFAULT_SEED`.

    """

    def __init__(
        self,
        neurons: Union[List[T], Mapping[NeuronKey, T]],
        matching_lists: List[List[NeuronKey]],
        match_fn: Callable[[T, T], List[np.ndarray]],
        nonmatching_list: Optional[List[NeuronKey]] = None,
        draw_strat: str = "batched",
        seed: int = DEFAULT_SEED,
    ) -> None:
        self.objects = neurons
        self.matching_lists = matching_lists
        self._nonmatching_list = nonmatching_list
        self.match_fn = match_fn
        self.nonmatching_draw = draw_strat

        # Axes of the lookup table: exactly one of these is populated via
        # `.with_digitizers()` / `.with_bin_counts()` before `.build()`.
        self.digitizers: Optional[List[Digitizer]] = None
        self.bin_counts: Optional[List[int]] = None

        self.seed = seed
        self._ndim: Optional[int] = None  # lazily computed by the `ndim` property

    @property
    def ndim(self) -> int:
        """Dimensionality of the lookup table.

        Determined lazily (and cached) by applying the match function to
        the first two objects and counting the returned arrays.
        """
        if self._ndim is None:
            # NOTE(review): `_object_keys()` returns `dict.keys()` for mapping
            # input, which does not support slicing — TODO confirm dict input.
            idx1, idx2 = self._object_keys()[:2]
            self._ndim = len(self._query(idx1, idx2))
        return self._ndim

    def with_digitizers(self, digitizers: List[Digitizer]):
        """Specify the axes of the output lookup table directly.

        Clears any previously set `bin_counts`.

        Parameters
        ----------
        digitizers : List[Digitizer]
            One digitizer per dimension of the match function's output.

        Returns
        -------
        self
            For chaining convenience.

        Raises
        ------
        ValueError
            If the number of digitizers does not match `self.ndim`.
        """
        if len(digitizers) != self.ndim:
            raise ValueError(
                f"Match function returns {self.ndim} values "
                f"but provided {len(digitizers)} digitizers"
            )

        self.digitizers = digitizers
        self.bin_counts = None
        return self

    def with_bin_counts(self, bin_counts: List[int], method="quantile"):
        """Specify the number of bins on each axis of the output lookup table.

        The bin boundaries will be learned from the matched-pair data during
        `.build()`; by default they evenly partition the data into quantiles
        in each dimension. Clears any previously set `digitizers`.

        Parameters
        ----------
        bin_counts : List[int]
            Number of bins per dimension of the match function's output.
        method :     'quantile' | 'geometric' | 'linear'
                     Method used to tile the data space.

        Returns
        -------
        self
            For chaining convenience.

        Raises
        ------
        ValueError
            If the number of bin counts does not match `self.ndim`.
        """
        if len(bin_counts) != self.ndim:
            raise ValueError(
                f"Match function returns {self.ndim} values "
                f"but provided {len(bin_counts)} bin counts"
            )

        self.bin_counts = bin_counts
        self.digitizers = None
        self.bin_method = method
        return self

    def _object_keys(self) -> Sequence[NeuronKey]:
        """Get all indices into objects instance member."""
        # EAFP: mappings expose `.keys()`; sequences fall back to positions.
        try:
            return self.objects.keys()
        except AttributeError:
            return range(len(self.objects))

    @property
    def nonmatching(self) -> List[NeuronKey]:
        """Indices of nonmatching set of neurons."""
        # If no explicit non-matching set was given, use the whole pool.
        if self._nonmatching_list is None:
            return list(self._object_keys())
        return self._nonmatching_list

    def _yield_matching_pairs(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
        """Yield all ordered index pairs within all matching lists."""
        for ms in self.matching_lists:
            yield from yield_not_same(permutations(ms, 2))

    def _yield_nonmatching_pairs(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
        """Yield all ordered index pairs within the explicit non-matching list."""
        if self._nonmatching_list is None:
            raise ValueError("Must provide non-matching pairs explicitly.")
        yield from yield_not_same(permutations(self._nonmatching_list, 2))

    def _yield_nonmatching_pairs_greedy(
        self, rng=None
    ) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
        """Yield all index pairs within nonmatching list."""
        # NOTE(review): the `rng` parameter is currently unused.
        return yield_not_same(permutations(self.nonmatching, 2))

    def _yield_nonmatching_pairs_batched(self) -> Iterator[Tuple[NeuronKey, NeuronKey]]:
        """Yield all index pairs within nonmatching list.

        This function tries to generate truly random draws of all possible
        non-matching pairs without actually having to generate all pairs.
        Instead, we generate new randomly permutated pairs in batches from which
        we then remove previously seen pairs.

        This works reasonably well as long as we only need a small subset
        of all possible non-matches. Otherwise this becomes inefficient.

        NOTE(review): this generator never terminates on its own; if the pool
        of unseen pairs is ever exhausted, the regeneration loop can spin
        without yielding — callers must not request more pairs than exist.
        """
        nonmatching = np.array(self.nonmatching)

        seen = []  # NOTE(review): dead assignment, overwritten below
        rng = np.random.default_rng(self.seed)

        # Generate random pairs
        pairs = np.vstack(
            (rng.permutation(nonmatching), rng.permutation(nonmatching))
        ).T
        pairs = pairs[pairs[:, 0] != pairs[:, 1]]  # drop self hits
        seen = set([tuple(p) for p in pairs])  # track already seen pairs
        i = 0
        while True:
            # If exhausted, generate a new batch of random permutation
            if i >= len(pairs):
                pairs = np.vstack(
                    (rng.permutation(nonmatching), rng.permutation(nonmatching))
                ).T
                pairs = pairs[pairs[:, 0] != pairs[:, 1]]  # drop self hits
                pairs = set([tuple(p) for p in pairs])
                pairs = pairs - seen
                seen = seen | pairs
                pairs = list(pairs)
                i = 0

            # Pick a pair
            ix1, ix2 = pairs[i]
            i += 1

            yield (ix1, ix2)

    def _empty_counts(self) -> np.ndarray:
        """Create an empty array in which to store counts; shape determined by digitizer sizes."""
        shape = [len(b) for b in self.digitizers]
        return np.zeros(shape, int)

    def _query(self, q_idx, t_idx) -> List[np.ndarray]:
        """Get the results of applying the match function to objects specified by indices."""
        return self.match_fn(self.objects[q_idx], self.objects[t_idx])

    def _query_many(self, idx_pairs, threads=None) -> Iterator[List[np.ndarray]]:
        """Yield results from querying many pairs of neuron indices.

        `threads=None` (or 0 on a single-core machine) runs in serial;
        otherwise the pairs are farmed out to a process pool.
        """
        if threads is None or (threads == 0 and cpu_count == 1):
            for q_idx, t_idx in idx_pairs:
                yield self._query(q_idx, t_idx)
            return

        threads = threads or cpu_count
        idx_pairs = np.asarray(idx_pairs)
        chunks = chunksize(len(idx_pairs), threads)

        with ProcessPoolExecutor(threads) as exe:
            yield from exe.map(
                self.match_fn,
                [self.objects[ix] for ix in idx_pairs[:, 0]],
                [self.objects[ix] for ix in idx_pairs[:, 1]],
                chunksize=chunks,
            )

    def _query_to_idxs(self, q_idx, t_idx, counts=None):
        """Produce a digitized counts array from a given query-target pair."""
        return self._count_results(self._query(q_idx, t_idx), counts)

    def _count_results(self, results: List[np.ndarray], counts=None):
        """Convert raw match function output into a digitized counts array.

        Requires digitizers.
        """
        # Digitize
        idxs = [dig(r) for dig, r in zip(self.digitizers, results)]

        # Make a stack
        stack = np.vstack(idxs).T

        # Create empty matrix if necessary
        if counts is None:
            counts = self._empty_counts()

        # Get counts per cell -> this is the actual bottleneck of this function
        cells, cnt = np.unique(stack, axis=0, return_counts=True)

        # Fill matrix
        counts[tuple(cells[:, i] for i in range(cells.shape[1]))] += cnt

        return counts

    def _counts_array(
        self,
        idx_pairs,
        threads=None,
        progress=True,
        desc=None,
    ):
        """Convert index pairs into a digitized counts array.

        Requires digitizers.
        """
        counts = self._empty_counts()
        if threads is None or (threads == 0 and cpu_count == 1):
            for q_idx, t_idx in config.tqdm(
                idx_pairs, leave=False, desc=desc, disable=not progress
            ):
                counts = self._query_to_idxs(q_idx, t_idx, counts)
            return counts

        threads = threads or cpu_count
        idx_pairs = np.asarray(idx_pairs, dtype=int)
        chunks = chunksize(len(idx_pairs), threads)

        # because digitizing is not necessarily free,
        # keep this parallelisation separate to that in _query_many
        with ProcessPoolExecutor(threads) as exe:
            # This is the progress bar
            with config.tqdm(
                desc=desc, total=len(idx_pairs), leave=False, disable=not progress
            ) as pbar:
                for distdots in exe.map(
                    self.match_fn,
                    [self.objects[ix] for ix in idx_pairs[:, 0]],
                    [self.objects[ix] for ix in idx_pairs[:, 1]],
                    chunksize=chunks,
                ):
                    counts = self._count_results(distdots, counts)
                    pbar.update(1)

        return counts

    def _pick_nonmatching_pairs(self, n_matching_qual_vals, progress=True):
        """Using the seeded RNG, pick which non-matching pairs to use."""
        # pre-calculating which pairs we're going to use,
        # rather than drawing them as we need them,
        # means that we can parallelise the later step more effectively.
        # Slowdowns here are practically meaningless
        # because of how long distdot calculation will take
        nonmatching_pairs = []
        n_nonmatching_qual_vals = 0
        if self.nonmatching_draw == "batched":
            # This is a generator that tries to generate random pairs in
            # batches to avoid having to calculate all possible pairs
            gen = self._yield_nonmatching_pairs_batched()
            with config.tqdm(
                desc="Drawing non-matching pairs",
                total=n_matching_qual_vals,
                leave=False,
                disable=not progress,
            ) as pbar:
                # Draw non-matching pairs until we have enough data
                for nonmatching_pair in gen:
                    nonmatching_pairs.append(nonmatching_pair)
                    # Each pair contributes one value per point in the query
                    new_vals = len(self.objects[nonmatching_pair[0]])
                    n_nonmatching_qual_vals += new_vals

                    pbar.update(new_vals)

                    if n_nonmatching_qual_vals >= n_matching_qual_vals:
                        break
        elif self.nonmatching_draw == "greedy":
            # Generate all possible non-matching pairs
            possible_pairs = len(self.nonmatching) ** 2 - len(self.nonmatching)
            all_nonmatching_pairs = [
                p
                for p in config.tqdm(
                    self._yield_nonmatching_pairs_greedy(),
                    total=possible_pairs,
                    desc="Generating non-matching pairs",
                )
            ]
            # Randomly pick non-matching pairs until we have enough data
            rng = np.random.default_rng(self.seed)
            with config.tqdm(
                desc="Drawing non-matching pairs",
                total=n_matching_qual_vals,
                leave=False,
                disable=not progress,
            ) as pbar:
                while n_nonmatching_qual_vals < n_matching_qual_vals:
                    # `pop` ensures each pair is drawn at most once
                    idx = rng.integers(0, len(all_nonmatching_pairs))
                    nonmatching_pair = all_nonmatching_pairs.pop(idx)
                    nonmatching_pairs.append(nonmatching_pair)

                    new_vals = len(self.objects[nonmatching_pair[0]])
                    n_nonmatching_qual_vals += new_vals

                    pbar.update(new_vals)
        else:
            raise ValueError(
                "Unknown strategy for non-matching pair draw:"
                f"{self.nonmatching_draw}"
            )

        return nonmatching_pairs

    def _get_pairs(self):
        """Return (matching_pairs, nonmatching_pairs) as deduplicated lists."""
        matching_pairs = list(set(self._yield_matching_pairs()))

        # If no explicit non-matches provided, pick them from the entire pool
        if self._nonmatching_list is None:
            # need to know the eventual distdot count
            # so we know how many non-matching pairs to draw
            q_idx_count = Counter(p[0] for p in matching_pairs)
            n_matching_qual_vals = sum(
                len(self.objects[q_idx]) * n_reps
                for q_idx, n_reps in q_idx_count.items()
            )

            nonmatching_pairs = self._pick_nonmatching_pairs(n_matching_qual_vals)
        else:
            nonmatching_pairs = list(set(self._yield_nonmatching_pairs()))

        return matching_pairs, nonmatching_pairs

    def _build(self, threads, progress=True) -> Tuple[List[Digitizer], np.ndarray]:
        """Compute digitizers and log2-odds cells; stores intermediates on self."""
        # Asking for more threads than available CPUs seems to crash on Github
        # actions
        if threads and threads >= cpu_count:
            threads = cpu_count

        if self.digitizers is None and self.bin_counts is None:
            raise ValueError(
                "Builder needs either digitizers or bin_counts - " "see with_* methods."
            )

        self.matching_pairs, self.nonmatching_pairs = self._get_pairs()

        logger.info("Comparing matching pairs")
        if self.digitizers:
            self.match_counts_ = self._counts_array(
                self.matching_pairs,
                threads=threads,
                progress=progress,
                desc="Comparing matching pairs",
            )
        else:
            # No explicit digitizers: learn bin boundaries from the
            # matching-pair data first, then count.
            match_results = concat_results(
                self._query_many(self.matching_pairs, threads),
                progress=progress,
                desc="Comparing matching pairs",
                total=len(self.matching_pairs),
            )
            self.match_results_ = match_results
            self.digitizers = []
            for i, (data, nbins) in enumerate(zip(match_results, self.bin_counts)):
                # `bin_counts` entries may also be pre-made Digitizers
                if not isinstance(nbins, Digitizer):
                    try:
                        self.digitizers.append(
                            Digitizer.from_data(data, nbins, method=self.bin_method)
                        )
                    except BaseException as e:
                        logger.error(f"Error creating digitizers for axes {i + 1}")
                        raise e
                else:
                    self.digitizers.append(nbins)

            logger.info("Counting results (this may take a while)")
            self.match_counts_ = self._count_results(match_results)

        logger.info("Comparing non-matching pairs")
        self.nonmatch_counts_ = self._counts_array(
            self.nonmatching_pairs,
            threads=threads,
            progress=progress,
            desc="Comparing non-matching pairs",
        )

        # Account for there being different total numbers of datapoints for
        # matches and nonmatches
        self.matching_factor_ = self.nonmatch_counts_.sum() / self.match_counts_.sum()
        if np.any(self.match_counts_ + self.nonmatch_counts_ == 0):
            logger.warning("Some lookup cells have no data in them")

        # log2 odds ratio of match vs non-match per cell; epsilon avoids
        # division by zero / log of zero in empty cells
        self.cells_ = np.log2(
            (self.match_counts_ * self.matching_factor_ + epsilon)
            / (self.nonmatch_counts_ + epsilon)
        )

        return self.digitizers, self.cells_

    def build(self, threads=None) -> LookupNd:
        """Build the score matrix.

        All non-identical neuron pairs within all matching sets are selected,
        and the scoring function is evaluated for those pairs.
        Then, the minimum number of non-matching pairs are randomly drawn
        so that at least as many data points can be calculated for non-matching
        pairs.

        In each bin of the score matrix, the log2 odds ratio of a score
        in that bin belonging to a match vs. non-match is calculated.

        Parameters
        ----------
        threads :   int, optional
                    If None, act in serial.
                    If 0, use cpu_count - 1.
                    Otherwise, use the given value.
                    Will be clipped at number of available cores - 1.
                    Note that with the current implementation a large number
                    of threads might (and somewhat counterintuitively) actually
                    be slower than building the scoring function in serial.

        Returns
        -------
        LookupNd
        """
        dig, cells = self._build(threads)
        return LookupNd(dig, cells)

Indices of nonmatching set of neurons.

Build the score matrix.

All non-identical neuron pairs within all matching sets are selected, and the scoring function is evaluated for those pairs. Then, the minimum number of non-matching pairs are randomly drawn so that at least as many data points can be calculated for non-matching pairs.

In each bin of the score matrix, the log2 odds ratio of a score in that bin belonging to a match vs. non-match is calculated.

PARAMETER DESCRIPTION
threads
    If None, act in serial.
    If 0, use cpu_count - 1.
    Otherwise, use the given value.
    Will be clipped at number of available cores - 1.
    Note that with the current implementation a large number
    of threads might (and somewhat counterintuitively) actually
    be slower than building the scoring function in serial.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
LookupNd
Source code in navis/nbl/smat.py
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
def build(self, threads=None) -> LookupNd:
    """Build and return the score matrix as a `LookupNd`.

    Every non-identical pair of neurons within each matching set is scored
    with the match function. Non-matching pairs are then drawn at random
    until they contribute at least as many data points as the matches.

    Each cell of the resulting table holds the log2 odds ratio of a score
    in that bin belonging to a match rather than a non-match.

    Parameters
    ----------
    threads :   int, optional
                None runs in serial; 0 uses cpu_count - 1; any other value
                is used directly (clipped at available cores - 1).
                Note that with the current implementation a large number
                of threads might (somewhat counterintuitively) be slower
                than building the scoring function in serial.

    Returns
    -------
    LookupNd
    """
    # _build returns (digitizers, cells); forward both to the lookup.
    return LookupNd(*self._build(threads))

Specify the number of bins on each axis of the output lookup table.

The bin boundaries will be determined by evenly partitioning the data from the matched pairs into quantiles, in each dimension.

PARAMETER DESCRIPTION
bin_counts

TYPE: List[int]

method
     Method used to tile the data space.

TYPE: 'quantile' | 'geometric' | 'linear' DEFAULT: 'quantile'

RETURNS DESCRIPTION
self

For chaining convenience.

Source code in navis/nbl/smat.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def with_bin_counts(self, bin_counts: List[int], method="quantile"):
    """Set how many bins each axis of the output lookup table should have.

    The actual bin boundaries are learned later from the matched-pair data,
    by default by partitioning each dimension into quantiles. Any previously
    set digitizers are discarded.

    Parameters
    ----------
    bin_counts : List[int]
        One bin count per dimension of the match function's output.
    method :     'quantile' | 'geometric' | 'linear'
                 Method used to tile the data space.

    Returns
    -------
    self
        For chaining convenience.

    Raises
    ------
    ValueError
        If the number of counts does not match the lookup dimensionality.
    """
    n_provided = len(bin_counts)
    if n_provided != self.ndim:
        raise ValueError(
            f"Match function returns {self.ndim} values "
            f"but provided {n_provided} bin counts"
        )

    self.bin_counts = bin_counts
    self.digitizers = None
    self.bin_method = method
    return self

Specify the axes of the output lookup table directly.

PARAMETER DESCRIPTION
digitizers

TYPE: List[Digitizer]

RETURNS DESCRIPTION
self

For chaining convenience.

Source code in navis/nbl/smat.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def with_digitizers(self, digitizers: List[Digitizer]):
    """Set the axes of the output lookup table explicitly.

    Any previously configured bin counts are discarded, since explicit
    digitizers fully determine the binning.

    Parameters
    ----------
    digitizers : List[Digitizer]
        One digitizer per dimension of the match function's output.

    Returns
    -------
    self
        For chaining convenience.

    Raises
    ------
    ValueError
        If the number of digitizers does not match the lookup dimensionality.
    """
    n_provided = len(digitizers)
    if n_provided != self.ndim:
        raise ValueError(
            f"Match function returns {self.ndim} values "
            f"but provided {n_provided} digitizers"
        )

    self.digitizers = digitizers
    self.bin_counts = None
    return self

Look up in a list of items and return their index.

PARAMETER DESCRIPTION
items

The item's position in the list is the index which will be returned.

TYPE: List[Hashable]

RAISES DESCRIPTION
ValueError

items are non-unique.

Source code in navis/nbl/smat.py
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
class SimpleLookup(LookupAxis[Hashable]):
    """Look up in a list of items and return their index.

    Parameters
    ----------
    items : List[Hashable]
        The item's position in the list is the index which will be returned.

    Raises
    ------
    ValueError
        items are non-unique.
    """

    def __init__(self, items: List[Hashable]):
        # Map each item to its position; duplicates would collapse here,
        # which is how non-uniqueness is detected below.
        self.items = {}
        for position, item in enumerate(items):
            self.items[item] = position
        if len(items) != len(self.items):
            raise ValueError("Items are not unique")

    def __len__(self) -> int:
        return len(self.items)

    def __call__(
        self, value: Union[Hashable, Sequence[Hashable]]
    ) -> Union[int, Sequence[int]]:
        # Scalars map to a single index; sequences to an integer array.
        if not np.isscalar(value):
            return np.array([self.items[v] for v in value], int)
        return self.items[value]

Checks functionally that the callable can be used as a score function.

PARAMETER DESCRIPTION
nargs

How many positional arguments the score function should have.

TYPE: optional int DEFAULT: 2

scalar

Check that the function can be used on nargs scalars.

TYPE: optional bool DEFAULT: True

array

Check that the function can be used on nargs 1D numpy.ndarrays.

TYPE: optional bool DEFAULT: True

RAISES DESCRIPTION
ValueError

If the score function is not appropriate.

Source code in navis/nbl/smat.py
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
def check_score_fn(fn: Callable, nargs=2, scalar=True, array=True):
    """Check functionally that the callable can be used as a score function.

    Parameters
    ----------
    fn : Callable
        Candidate score function.
    nargs : optional int, default 2
        How many positional arguments the score function should have.
    scalar : optional bool, default True
        Check that the function can be used on `nargs` scalars.
    array : optional bool, default True
        Check that the function can be used on `nargs` 1D `numpy.ndarray`s.

    Raises
    ------
    ValueError
        If the score function is not appropriate.
    """
    if scalar:
        scalars = [0.5] * nargs
        if not isinstance(fn(*scalars), float):
            # Bug fix: the arity in the message was hard-coded to 2
            # even when `nargs` differed.
            raise ValueError(
                f"smat does not take {nargs} floats and return a float"
            )

    if array:
        test_arr = np.array([0.5] * 3)
        arrs = [test_arr] * nargs
        try:
            out = fn(*arrs)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Failed to use smat with numpy arrays: {e}") from e

        if out.shape != test_arr.shape:
            raise ValueError(
                f"smat produced inconsistent shape: input {test_arr.shape}; output {out.shape}"
            )
Source code in navis/nbl/smat.py
50
51
def chunksize(it_len, cpu_count, min_chunk=50):
    """Return a chunk size splitting `it_len` items across `cpu_count` workers.

    Targets roughly four chunks per worker but never goes below
    `min_chunk` items per chunk.
    """
    per_worker = int(it_len / (cpu_count * 4))
    return max(min_chunk, per_worker)

Helper function to concatenate batches of e.g. [(dist, dots), (dist, dots)] into single (dist, dot) arrays.

Source code in navis/nbl/smat.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def concat_results(
    results: Iterable[List[np.ndarray]],
    total: Optional[int] = None,
    desc: str = "Querying",
    progress: bool = True,
) -> List[np.ndarray]:
    """Helper function to concatenate batches of e.g. [(dist, dots), (dist, dots)]
    into single (dist, dot) arrays.
    """
    # Collect the i-th array of every batch under key i, then
    # concatenate each collection into one long array per dimension.
    collected = defaultdict(list)
    with config.tqdm(desc=desc, total=total, leave=False, disable=not progress) as pbar:
        for batch in results:
            for dim, arr in enumerate(batch):
                collected[dim].append(arr)
            pbar.update(1)

    return [np.concatenate(chunks) for chunks in collected.values()]
Source code in navis/nbl/smat.py
587
588
def dist_dot(q: "core.Dotprops", t: "core.Dotprops"):
    """Return `q.dist_dots(t)` (distances and dot products) as a list."""
    return [*q.dist_dots(t)]
Source code in navis/nbl/smat.py
591
592
593
def dist_dot_alpha(q: "core.Dotprops", t: "core.Dotprops"):
    """Return distances and alpha-weighted dot products between `q` and `t`."""
    results = q.dist_dots(t, alpha=True)
    distances = results[0]
    # Weight each dot product by the square root of its alpha value
    weighted_dots = results[1] * np.sqrt(results[2])
    return [distances, weighted_dots]
Source code in navis/nbl/smat.py
658
659
660
661
662
def is_monotonically_increasing(lst):
    """Return True if every element is strictly greater than its predecessor.

    Empty and single-element sequences count as increasing.
    """
    return all(earlier < later for earlier, later in zip(lst, lst[1:]))
Source code in navis/nbl/smat.py
665
666
667
668
669
670
671
672
673
674
675
def parse_boundary(item: str):
    """Parse a half-open interval string like "[1.0,2.0)".

    Returns
    -------
    tuple
        ((lower, upper), right) where `right` is True for a "(]" interval
        (right boundary included) and False for "[)".

    Raises
    ------
    ValueError
        If the enclosing characters do not describe a half-open interval.
    """
    enclosing = item[0] + item[-1]
    if enclosing == "(]":
        right = True
    elif enclosing == "[)":
        right = False
    else:
        raise ValueError(
            f"Enclosing characters '{enclosing}' do not match a half-open interval"
        )
    bounds = tuple(float(part) for part in item[1:-1].split(","))
    return bounds, right
Source code in navis/nbl/smat.py
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
def parse_score_fn(smat, alpha=False):
    """Interpret `smat` as a score function.

    Primarily for backwards compatibility.

    Parameters
    ----------
    smat : None | "auto" | str | os.PathLike | pandas.DataFrame | Callable[[float, float], float]
        If `None`, use `operator.mul`.
        If `"auto"`, use `navis.nbl.smat.smat_fcwb(alpha)`.
        If a dataframe, use `navis.nbl.smat.Lookup2d.from_dataframe(smat)`.
        If another string or path-like, load from CSV in a dataframe and uses as above.
        Also checks the signature of the callable.
        Raises an error, probably a ValueError, if it can't be interpreted.
    alpha : optional bool, default False
        If `smat` is `"auto"`, choose whether to use the FCWB matrices
        with or without alpha.

    Returns
    -------
    Callable

    Raises
    ------
    ValueError
        If score function cannot be interpreted.
    """
    # Bug fix: the original "docstring" was an f-string, which Python does
    # NOT treat as a docstring — `__doc__` was None and the string was
    # re-evaluated (referencing SCORE_FN_DESCR) and discarded on every call.
    if smat is None:
        smat = operator.mul
    elif smat == "auto":
        smat = smat_fcwb(alpha)

    # A path/string (other than "auto") is read as a CSV score matrix
    if isinstance(smat, (str, os.PathLike)):
        smat = pd.read_csv(smat, index_col=0)

    if isinstance(smat, pd.DataFrame):
        smat = Lookup2d.from_dataframe(smat)

    if not callable(smat):
        raise ValueError(
            "smat should be a callable, a path, a pandas.DataFrame, or 'auto'"
        )

    # Verify the callable behaves like a score function before returning it
    check_score_fn(smat)

    return smat
Source code in navis/nbl/smat.py
54
55
56
57
def yield_not_same(pairs: Iterable[Tuple[Any, Any]]) -> Iterator[Tuple[Any, Any]]:
    """Yield only those pairs whose two elements differ."""
    for first, second in pairs:
        if first == second:
            continue
        yield first, second