"""Utility functions."""
import numpy as np
from numba import jit, int32, types
# Names exported by a star-import of this module. tri_base and el_block_size
# are defined further down but are not listed here.
__all__ = ['resize_axis', 'tri_ravel', 'unravel_lm', 'ravel_lm', 'get_grid_sampling']
def resize_axis(arr, size, mode='zero', axis=0):
    """
    Resize an axis of the array, either truncating or zero-padding.

    The "mode" keyword determines how this is done.

    Parameters
    ----------
    arr: ndarray
        Array to resize
    size: int
        New size for the axis.
    mode: str
        Where the old data sits along the resized axis:

        * 'zero'  : Keep the leading and trailing samples at the two ends of
          the axis; zeros (when padding) or the cut (when truncating) go in
          the middle (center data on zero).
        * 'start' : Keep the data at the start; zeros go at the end.
        * 'center': Center the data, padding or trimming both sides. When the
          size difference is odd, the extra sample is on the trailing side.

        Defaults to "zero".
    axis: int
        Which axis to resize

    Returns
    -------
    arr: ndarray
        New array with the dtype of the input and the specified axis padded
        or truncated.

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'zero', 'start' or 'center'.

    Notes
    -----
    The 'zero' mode can be used to zero-pad an FFT-transformed array before
    applying an inverse transform.
    """
    # Reject typos up front instead of silently returning an all-zero array.
    if mode not in ('zero', 'start', 'center'):
        raise ValueError(
            "Unknown mode %r; expected 'zero', 'start' or 'center'." % (mode,))

    old_len = arr.shape[axis]
    shape = list(arr.shape)
    shape[axis] = size
    new = np.zeros(tuple(shape), dtype=arr.dtype)

    # Work with the target axis in front. Both are views, so every write to
    # ``dst`` lands directly in ``new``.
    src = np.swapaxes(arr, axis, 0)
    dst = np.swapaxes(new, axis, 0)

    if mode == 'zero':
        # A length-n axis is split into ceil(n / 2) leading samples and
        # floor(n / 2) trailing samples. Copy as many of each as both the old
        # and the new axis can hold, so truncating an even-length axis to an
        # odd size keeps the last leading sample instead of leaving it zero.
        head = min((old_len + 1) // 2, (size + 1) // 2)
        tail = min(old_len // 2, size // 2)
        dst[:head, ...] = src[:head, ...]
        if tail:
            # Guarded because a "-0:" slice would select the whole axis.
            dst[-tail:, ...] = src[-tail:, ...]
    elif mode == 'start':
        keep = min(size, old_len)
        dst[:keep, ...] = src[:keep, ...]
    else:  # mode == 'center'
        if size <= old_len:
            # Truncating: keep a window of ``size`` samples from the middle.
            base = (old_len - size) // 2
            dst[:, ...] = src[base:base + size, ...]
        else:
            # Padding: drop the old data into the middle of the new axis.
            base = (size - old_len) // 2
            dst[base:base + old_len, ...] = src[:, ...]
    return new
# Index raveling/unraveling
@jit(int32(int32, int32, int32), nopython=True)
def tri_ravel(l, m1, m2):
    """
    Ravel (l, m1, m2) into a flat index for the 'stack of triangles' ordering.

    The indices must satisfy ``0 <= m2 <= m1 <= l``; otherwise a ValueError
    is raised. Within the block for a given ``l`` the rows are laid out from
    ``m1 = l`` down to ``m1 = 0``, and the row for ``m1`` holds ``m1 + 1``
    entries (``m2 = 0 .. m1``).
    """
    in_range = 0 <= m2 and m2 <= m1 and m1 <= l
    if not in_range:
        raise ValueError("Invalid indices")
    # Number of entries used by all blocks below l.
    block_start = l * (l + 1) * (l + 2) // 6
    # Number of entries in the rows of this block that come before row m1.
    row_start = (l - m1) * (l + 3 + m1) // 2
    return int(block_start + row_start + m2)
@jit(int32(int32), nopython=True)
def tri_base(l):
    """Return the smallest flat index belonging to the block for this el."""
    # (m1, m2) = (l, 0) is the entry with zero offset inside the block, so its
    # raveled index is where the block starts.
    first_index = tri_ravel(l, l, 0)
    return first_index
@jit(int32(int32), nopython=True)
def el_block_size(l):
    """
    Number of entries needed for a given el in the dmat array.

    This is the triangular number ``(l + 1) * (l + 2) / 2``.
    """
    rows = l + 1
    return rows * (rows + 1) // 2
@jit(int32(int32, int32), nopython=True)
def unravel_lm(el, m):
    """Map an (el, em) pair to its flat index, ``el * (el + 1) + m``."""
    index = el * (el + 1)
    index += m
    return index
@jit(types.Tuple((int32, int32))(int32), nopython=True)
def ravel_lm(ind):
    """
    Get (el, em) from a flat index.

    Inverts ``ind = el * (el + 1) + em``: ``el`` is the integer part of
    ``sqrt(ind)`` and ``em`` is what remains after subtracting
    ``el * (el + 1)``.

    Raises
    ------
    ValueError
        If ``ind`` is negative.
    """
    # A negative index has no real square root; reject it rather than feed
    # NaN into the integer conversion below.
    if ind < 0:
        raise ValueError("Invalid index")
    el = int(np.floor(np.sqrt(ind)))
    m = ind - el * (el + 1)
    return el, m
def get_grid_sampling(lmax=None, Nt=None, Nf=None):
    """
    Get sample positions for "grid", compatible with the MW transform methods.

    Parameters
    ----------
    lmax: int
        Maximum multipole moment. (Optional if both Nt and Nf are set.)
    Nt: int
        Number of samples in theta. (Optional, defaults to lmax)
    Nf: int
        Number of samples in phi for all colatitudes
        (Optional, defaults to (2 * lmax - 1))

    Returns
    -------
    thetas: ndarray
        Colatitudes in radians: Nt evenly spaced values running from
        pi / (2 * Nt - 1) up to and including pi.
    phis: ndarray
        Azimuths in radians: Nf evenly spaced values starting at 0, with
        2 * pi itself excluded.

    Raises
    ------
    ValueError
        If lmax is not given and Nt or Nf is missing.
    """
    counts_given = Nt is not None and Nf is not None
    if lmax is None and not counts_given:
        raise ValueError("Need to provide lmax if Nt and Nf are unset.")

    # Past the guard, lmax is known whenever one of the counts is missing.
    n_theta = lmax if Nt is None else Nt
    n_phi = 2 * lmax - 1 if Nf is None else Nf

    first_theta = np.pi / (2 * n_theta - 1)
    thetas = np.linspace(first_theta, np.pi, n_theta, endpoint=True)
    phis = np.linspace(0, 2 * np.pi, n_phi, endpoint=False)
    return thetas, phis