Skip to content

base

Optimizes a Transform or TransformSequence.

The purpose of this class is to change a bunch of settings (depending on the type of transforms) before running the transformation and do a clean up after it has finished.

Currently, the optimizations are very hands-on but in the future, this might be delegated to the individual Transform classes - e.g. by implementing an .optimize() method which is then called by TransOptimizer.

Currently, it really only manages caching for H5 transforms.

PARAMETER DESCRIPTION
tr
    The transform or sequence thereof to be optimized.

TYPE: Transform | TransformSequence

bbox
    Bounding box of the points that will be transformed: three rows (x, y, z), each holding two values (e.g. lower and upper bound).

TYPE: array-like of shape (3, 2)

caching
    Whether to pre-cache the deformation fields of H5 transforms within the bounding box. If False, the optimizer does nothing.

TYPE: bool

Examples:

>>> from navis.transforms import h5reg
>>> from navis.transforms.base import TransOptimizer
>>> tr = h5reg.H5transform('path/to/reg.h5', direction='inverse')
>>> bbox = [[0, 100], [0, 100], [0, 100]]
>>> with TransOptimizer(tr, bbox=bbox, caching=True):
...     xf = tr.xform(pts)
Source code in navis/transforms/base.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
class TransOptimizer:
    """Optimizes a Transform or TransformSequence.

    The purpose of this class is to change a bunch of settings (depending on the
    type of transforms) before running the transformation and do a clean up
    after it has finished.

    Currently, the optimizations are very hands-on but in the future, this might
    be delegated to the individual `Transform` classes - e.g. by implementing
    an `.optimize()` method which is then called by `TransOptimizer`.

    Currently, it really only manages caching for H5 transforms: on entering
    the context, each H5 transform pre-caches its deformation field within the
    bounding box; on exit, ``use_cache`` is set to False on those transforms.

    Parameters
    ----------
    tr :        Transform | TransformSequence
                The transform or sequence thereof to be optimized.
    bbox :      array-like of shape (3, 2)
                Bounding box of the points that will be transformed: one row
                per axis (x, y, z), two values per row (e.g. lower and upper
                bound).
    caching :   bool
                Whether to pre-cache deformation fields of H5 transforms.
                If False, entering/exiting the context does nothing.

    Examples
    --------
    >>> from navis.transforms import h5reg
    >>> from navis.transforms.base import TransOptimizer
    >>> tr = h5reg.H5transform('path/to/reg.h5', direction='inverse') # doctest: +SKIP
    >>> bbox = [[0, 100], [0, 100], [0, 100]]                          # doctest: +SKIP
    >>> with TransOptimizer(tr, bbox=bbox, caching=True):              # doctest: +SKIP
    ...     xf = tr.xform(pts)                                         # doctest: +SKIP

    """

    def __init__(self, tr, bbox, caching: bool):
        """Initialize Optimizer.

        Raises
        ------
        TypeError
            If `caching` is not a bool or `tr` is neither a transform nor a
            sequence of transforms.
        ValueError
            If `bbox` is not of shape (3, 2).

        """
        # Explicit checks instead of `assert` so validation survives `python -O`
        if not isinstance(caching, bool):
            raise TypeError(f'`caching` must be bool, got "{type(caching)}"')

        self.caching = caching
        self.bbox = np.asarray(bbox)

        if self.bbox.shape != (3, 2):
            raise ValueError(f'`bbox` must be of shape (3, 2), got {self.bbox.shape}')

        if isinstance(tr, BaseTransform):
            self.transforms = [tr]
        elif isinstance(tr, TransformSequence):
            self.transforms = tr.transforms
        else:
            raise TypeError(f'Expected Transform/Sequence, got "{type(tr)}"')

    @staticmethod
    def _is_h5(tr) -> bool:
        """Return True if `tr` is an H5 transform (checked via its type name)."""
        # Comparing the type's string representation avoids importing the
        # H5 module here (which would create a circular import)
        return 'H5transform' in str(type(tr))

    def __enter__(self):
        """Apply optimizations.

        If caching is enabled, every H5 transform pre-caches its deformation
        field within the bounding box. Because transforms are applied in
        sequence, the bounding box is pushed through each preceding transform
        so later H5 transforms cache the region the points will occupy by the
        time they reach them.

        Returns
        -------
        TransOptimizer
            This instance, so ``with TransOptimizer(...) as opt:`` works.

        """
        if not self.caching:
            return self

        # Positions of the transforms we know how to optimize
        h5_ix = [i for i, tr in enumerate(self.transforms) if self._is_h5(tr)]
        if not h5_ix:
            return self

        if not config.pbar_hide:
            logger.info('Pre-caching deformation field(s) for transforms...')

        last_h5 = h5_ix[-1]
        bbox_xf = self.bbox
        for i, tr in enumerate(self.transforms[:last_h5 + 1]):
            if self._is_h5(tr):
                # Precache values in the bounding box
                tr.precache(bbox_xf, padding=True)
            if i == last_h5:
                # No H5 transform downstream needs the bbox anymore, so skip
                # the (potentially expensive) remaining xform(s)
                break
            # Transform the two corners (the rows of bbox.T), then re-sort
            # each axis so the columns are (lower, upper) again: a transform
            # that flips an axis (e.g. mirroring) would otherwise swap them
            bbox_xf = np.sort(tr.xform(bbox_xf.T).T, axis=1)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Revert optimizations.

        Sets ``use_cache = False`` on every H5 transform. Returns None, so any
        exception raised inside the ``with`` block propagates.

        """
        if not self.caching:
            return

        for tr in self.transforms:
            if self._is_h5(tr):
                # Switch caching off again (presumably this also discards the
                # cached field - depends on the H5 transform implementation)
                tr.use_cache = False

Initialize Optimizer.

Source code in navis/transforms/base.py
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
def __init__(self, tr, bbox, caching: bool):
    """Initialize Optimizer."""
    assert isinstance(caching, bool)

    self.caching = caching
    self.bbox = np.asarray(bbox)

    assert self.bbox.ndim == 2 and self.bbox.shape == (3, 2)

    if isinstance(tr, BaseTransform):
        self.transforms = [tr]
    elif isinstance(tr, TransformSequence):
        self.transforms = tr.transforms
    else:
        raise TypeError(f'Expected Transform/Sequence, got "{type(tr)}"')

A sequence of transforms.

Use this to apply multiple (different types of) transforms in sequence.

PARAMETER DESCRIPTION
*transforms
        The transforms to bundle in this sequence.

TYPE: Transform/Sequences. DEFAULT: ()

copy
        Whether to make a copy of the transform on initialization.
        This is highly recommended because otherwise we might alter
        the original as we add more transforms (e.g. for CMTK
        transforms).

TYPE: bool DEFAULT: True

Source code in navis/transforms/base.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
class TransformSequence:
    """A sequence of transforms.

    Use this to apply multiple (different types of) transforms in sequence.

    Parameters
    ----------
    *transforms :   Transform/Sequences.
                    The transforms to bundle in this sequence. Sequences are
                    unpacked into their individual transforms.
    copy :          bool
                    Whether to make a copy of the transform on initialization.
                    This is highly recommended because otherwise we might alter
                    the original as we add more transforms (e.g. for CMTK
                    transforms).

    """

    def __init__(self, *transforms, copy=True):
        """Initialize."""
        self.transforms = []
        for tr in transforms:
            if not isinstance(tr, (BaseTransform, TransformSequence)):
                raise TypeError(f'Expected transform, got "{type(tr)}"')
            if copy:
                tr = tr.copy()
            self.append(tr)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f'TransformSequence with {len(self)} transform(s)'

    def __len__(self) -> int:
        """Count number of transforms in this sequence."""
        return len(self.transforms)

    def __neg__(self) -> 'TransformSequence':
        """Invert transform sequence."""
        return TransformSequence(*[-t for t in self.transforms[::-1]])

    def copy(self) -> 'TransformSequence':
        """Return a copy of this sequence; each transform is copied as well.

        Required because ``__init__`` calls ``.copy()`` on every argument,
        including nested sequences.
        """
        return TransformSequence(*self.transforms, copy=True)

    def append(self, transform: 'BaseTransform | TransformSequence'):
        """Add transform(s) to the end of the sequence.

        A sequence is unpacked into its individual transforms. Each new
        transform is first offered to the current last transform via its
        ``.append()`` (i.e. an attempt to merge the two); if that raises
        ``NotImplementedError`` it is added as a separate step instead.

        Parameters
        ----------
        transform :     Transform | TransformSequence
                        Transform(s) to add.

        Raises
        ------
        TypeError
            If an item is not a ``BaseTransform`` or has no callable
            ``xform`` method.

        """
        if isinstance(transform, TransformSequence):
            # Unpack if other is sequence of transforms. Snapshot the list so
            # that appending a sequence to itself does not loop forever.
            transform = list(transform.transforms)

        for tr in utils.make_iterable(transform):
            if not isinstance(tr, BaseTransform):
                raise TypeError(f'Unable append "{type(tr)}"')

            # Validate the individual transform - not the container we iterate
            # over (checking the container made appending sequences fail)
            if not callable(getattr(tr, 'xform', None)):
                raise TypeError('Transform does not appear to have a `xform` method')

            if not self.transforms:
                self.transforms.append(tr)
                continue

            # Try to merge with the last transform in the sequence
            try:
                self.transforms[-1].append(tr)
            except NotImplementedError:
                self.transforms.append(tr)

    def xform(self, points: np.ndarray,
              affine_fallback: bool = True,
              **kwargs) -> np.ndarray:
        """Perform transforms in sequence.

        Parameters
        ----------
        points :            array-like
                            2D array with one point per row.
        affine_fallback :   bool
                            Passed to every transform whose ``xform`` accepts
                            an ``affine_fallback`` argument.
        **kwargs
                            Passed to every transform's ``xform``.

        Returns
        -------
        np.ndarray
            float64 copy of `points` after all transforms. Rows containing NaN
            are not passed on to subsequent transforms.

        """
        # First check if any of the transforms raise any issues ahead of time
        # This can e.g. be missing binaries like CMTK's streamxform
        for tr in self.transforms:
            tr.check_if_possible(on_error='raise')

        # Now transform points in sequence
        # Make a float64 copy: this avoids changing the originals and keeps
        # enough precision in case it must go up
        # -> e.g. when converting from nm to micron space
        xf = np.asarray(points).astype(np.float64)
        for tr in self.transforms:
            # Check this transform's signature for accepted parameters
            params = signature(tr.xform).parameters

            # We must not pass NaN value from one transform to the next
            is_nan = np.any(np.isnan(xf), axis=1)

            # Skip if all points are NaN
            if is_nan.all():
                continue

            if 'affine_fallback' in params:
                xf[~is_nan] = tr.xform(xf[~is_nan],
                                       affine_fallback=affine_fallback,
                                       **kwargs)
            else:
                xf[~is_nan] = tr.xform(xf[~is_nan], **kwargs)

        return xf

Initialize.

Source code in navis/transforms/base.py
178
179
180
181
182
183
184
185
186
def __init__(self, *transforms, copy=True):
    """Build the sequence from zero or more transforms or sequences.

    Every argument is type-checked, copied when `copy` is True, and then
    handed to ``self.append``.

    Raises
    ------
    TypeError
        If an argument is neither a transform nor a transform sequence.

    """
    self.transforms = []
    for item in transforms:
        if isinstance(item, (BaseTransform, TransformSequence)):
            self.append(item.copy() if copy else item)
        else:
            raise TypeError(f'Expected transform, got "{type(item)}"')

Add transform to list.

Source code in navis/transforms/base.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def append(self, transform: 'BaseTransform | TransformSequence'):
    """Add transform(s) to the end of the sequence.

    A sequence is unpacked into its individual transforms. Each new transform
    is first offered to the current last transform via its ``.append()``
    (i.e. an attempt to merge the two); if that raises
    ``NotImplementedError`` it is added as a separate step instead.

    Parameters
    ----------
    transform :     Transform | TransformSequence
                    Transform(s) to add.

    Raises
    ------
    TypeError
        If an item is not a ``BaseTransform`` or has no callable ``xform``
        method.

    """
    if isinstance(transform, TransformSequence):
        # Unpack if other is sequence of transforms. Snapshot the list so
        # that appending a sequence to itself does not loop forever.
        transform = list(transform.transforms)

    for tr in utils.make_iterable(transform):
        if not isinstance(tr, BaseTransform):
            raise TypeError(f'Unable append "{type(tr)}"')

        # Validate the individual transform - not the container we iterate
        # over (checking the container made appending sequences fail)
        if not callable(getattr(tr, 'xform', None)):
            raise TypeError('Transform does not appear to have a `xform` method')

        if not self.transforms:
            self.transforms.append(tr)
            continue

        # Try to merge with the last transform in the sequence
        try:
            self.transforms[-1].append(tr)
        except NotImplementedError:
            self.transforms.append(tr)

Perform transforms in sequence.

Source code in navis/transforms/base.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def xform(self, points: np.ndarray,
          affine_fallback: bool = True,
          **kwargs) -> np.ndarray:
    """Apply each transform of the sequence to `points`, in order.

    Parameters
    ----------
    points :            array-like
                        2D array with one point per row.
    affine_fallback :   bool
                        Forwarded to every transform whose ``xform`` has an
                        ``affine_fallback`` parameter.
    **kwargs
                        Forwarded to every transform's ``xform``.

    Returns
    -------
    np.ndarray
        float64 copy of `points` after all transforms. Rows containing NaN are
        not handed to subsequent transforms.

    """
    # Fail before doing any work (e.g. a missing CMTK streamxform binary)
    for step in self.transforms:
        step.check_if_possible(on_error='raise')

    # Work on a float64 copy: the caller's array stays untouched and there is
    # headroom for precision changes (e.g. nm -> micron conversions)
    coords = np.asarray(points).astype(np.float64)

    for step in self.transforms:
        accepted = signature(step.xform).parameters

        # Only rows without any NaN are passed on to the next transform
        valid = ~np.any(np.isnan(coords), axis=1)
        if not any(valid):
            continue

        if 'affine_fallback' in accepted:
            extra = {'affine_fallback': affine_fallback, **kwargs}
        else:
            extra = kwargs
        coords[valid] = step.xform(coords[valid], **extra)

    return coords

Trigger delayed initialization.

Source code in navis/transforms/base.py
27
28
29
30
31
32
33
34
35
36
def trigger_init(func):
    """Decorate a method so the instance's delayed initialization runs first.

    The first positional argument is treated as the instance. If its
    ``initialized`` attribute is falsy, ``__delayed_init__()`` is called
    before `func`; `func` is then called with the unchanged arguments and its
    result returned.
    """
    @functools.wraps(func)
    def _ensure_initialized(*args, **kwargs):
        instance = args[0]
        needs_setup = not instance.initialized
        if needs_setup:
            # First use of this instance: perform the deferred setup now
            instance.__delayed_init__()
        return func(*args, **kwargs)

    return _ensure_initialized