Source code for gumerov_expansion_coefficients._main

# https://github.com/search?q=gumerov+translation+language%3APython&type=code&l=Python
from math import prod, sqrt
from typing import Any

import numba
from array_api._2024_12 import Array
from array_api_compat import array_namespace
from numba import complex64, complex128, float32, float64, jit, prange
from numba.cuda.cudadrv.error import CudaSupportError

from gumerov_expansion_coefficients._elementary_solutions import RS_all, idx_all, minus_1_power

# Gumerov's notation
# E^m_n = sum_{m'n'} (E|F)^{m' m}_{n' n} F^{m'}_{n'}
# (E|F)^{m' m}_{n'} := (E|F)^{m' m}_{n' |m|}
# (E|F)^{m' m}_{,n} := (E|F)^{m' m}_{|m'| n}


@jit()
def idx_i(n: int, m: int, /) -> int:
    """Flat index of the coefficient with degree ``n`` and order ``m``.

    Scalar (numba-jitted) counterpart of :func:`idx`. Degree ``n`` occupies
    the flat positions ``n**2 .. n**2 + 2n``. Inside that block the orders
    appear as ``0, 1, ..., n, -n, ..., -1``, because ``m % (2n + 1)`` (floor
    modulo) wraps negative orders to the end of the block.

    No range check is performed: an order with ``|m| > n`` aliases another
    entry, so callers must ensure ``|m| <= n``.
    """
    # (0, 0) -> 0
    # (1, 0) -> 1
    # (1, 1) -> 2
    # (1, -1) -> 3
    return n**2 + (m % (2 * n + 1))


def idx(n: Array, m: Array, /) -> Array:
    """Flat index of the coefficients with degree ``n`` and order ``m``.

    Array counterpart of :func:`idx_i` (elementwise over ``n`` and ``m``).
    Degree ``n`` occupies the flat positions ``n**2 .. n**2 + 2n``. Within
    it the orders appear as ``0, 1, ..., n, -n, ..., -1`` (floor modulo
    ``m % (2n + 1)`` wraps negative orders to the end of the block).

    Entries with ``|m| > n`` are mapped to the sentinel ``-1``.
    """
    # (0, 0) -> 0
    # (1, 0) -> 1
    # (1, 1) -> 2
    # (1, -1) -> 3
    xp = array_namespace(n, m)
    m_abs = xp.abs(m)
    return xp.where(m_abs > n, -1, n**2 + (m % (2 * n + 1)))


@jit(inline="always")
def ndim_harm(n_end: int, /) -> int:
    """Count the spherical harmonics whose degree is below ``n_end``.

    Degree ``n`` contributes ``2n + 1`` orders. Summing over
    ``n = 0 .. n_end - 1`` gives ``n_end`` squared.
    """
    return n_end * n_end


# Numba-compiled alias of ``minus_1_power`` so that jitted kernels (e.g. the
# symmetry step of ``_translational_coefficients_all``) can call it;
# ``inline="always"`` asks numba to inline it at every call site.
minus_1_power_jit = jit(inline="always")(minus_1_power)


def translational_coefficients_sectorial_init(
    kr: Array, theta: Array, phi: Array, same: bool, n_end: int, /
) -> Array:
    """Initial values of sectorial translational coefficients (E|F)^{m',0}_{n', 0}

    Evaluates ``minus_1_power(n') * sqrt(4 pi) * F[idx(n', -m')]`` for every
    degree ``n' < 2 * n_end - 1``. Here ``F`` is ``RS_all`` evaluated with
    ``type="regular"`` (``same=True``) or ``type="singular"``
    (``same=False``). The extra degrees beyond ``n_end`` are consumed by the
    recurrences in ``translational_coefficients_all``.

    Parameters
    ----------
    kr : Array
        k * r of shape (...,)
    theta : Array
        polar angle of shape (...,)
    phi : Array
        azimuthal angle of shape (...,)
    same : bool
        If True, return (R|R) = (S|S).
        If False, return (S|R).
    n_end : int
        Exclusive upper bound of the degrees of the final translational
        coefficients. This function itself covers degrees below
        ``2 * n_end - 1``.

    Returns
    -------
    Array
        Initial sectorial translational coefficients of shape
        (..., ndim_harm(2 * n_end - 1),). The downstream kernel reads the
        last axis with ``idx_i``.
    """
    xp = array_namespace(kr, theta, phi)
    # The recurrences need degrees up to 2 * (n_end - 1).
    n_end_init = 2 * n_end - 1
    n, m = idx_all(n_end_init, xp=xp, dtype=xp.int32, device=kr.device)
    # 4.43 / 4.58
    # (E|F)^{m' 0}_{n' 0} = (E|F)^{m' 0}_{n'}
    sign = minus_1_power(n)
    sqrt_4pi = xp.sqrt(xp.asarray(4.0, dtype=kr.dtype, device=kr.device) * xp.pi)
    rs = RS_all(kr, theta, phi, n_end=n_end_init, type="regular" if same else "singular")
    return sign * sqrt_4pi * rs[..., idx(n, -m)]


@jit(inline="always")
def _set_coef(a: Array, nd: int, md: int, n: int, m: int, value: float, swap: bool = False) -> None:
    """Store ``value`` for the harmonic pair (nd, md) / (n, m) in matrix ``a``.

    Without ``swap`` the entry ``a[idx_i(nd, md), idx_i(n, m)]`` is written.
    With ``swap=True`` the row and column roles are exchanged, so
    ``a[idx_i(n, m), idx_i(nd, md)]`` is written instead.
    """
    first = idx_i(nd, md)
    second = idx_i(n, m)
    if not swap:
        a[first, second] = value
        return
    a[second, first] = value


@jit(inline="always")
def _get_coef(a: Array, nd: int, md: int, n: int, m: int, swap: bool = False) -> float:
    """Read the entry of ``a`` for the harmonic pair (nd, md) / (n, m).

    Mirror of ``_set_coef``. By default it returns
    ``a[idx_i(nd, md), idx_i(n, m)]``. When ``swap`` is True it returns
    ``a[idx_i(n, m), idx_i(nd, md)]``.
    """
    row = idx_i(nd, md)
    col = idx_i(n, m)
    if swap:
        row, col = col, row
    return a[row, col]


def _translational_coefficients_all(
    translational_coefficients_sectorial_init: Array, ret: Array, _: Array, /
) -> None:
    """Translational coefficients (E|F)^{m',m}_{n',n}

    Kernel body compiled with ``numba.guvectorize`` (layout
    ``"(n),(n,n)->()"``, see the module-level loop below). It fills ``ret``
    in place, entry ``ret[idx_i(n', m'), idx_i(n, m)]`` (see ``_set_coef`` /
    ``_get_coef``), starting from the sectorial initial values and applying
    recurrences in order and degree.

    The recurrence coefficients ``a`` and ``b`` are module-level globals that
    are defined, once per precision, after this function and right before it
    is compiled.

    Parameters
    ----------
    translational_coefficients_sectorial_init : Array
        Initial sectorial translational coefficients (E|F)^{m',0}_{n', 0}
        of shape (..., ndim_harm(2 * n_end - 1),)
    ret : Array
        Zero-initialized array that receives the result, of shape
        (..., ndim_harm(2 * n_end - 1), ndim_harm(2 * n_end - 1)). The caller
        keeps only the leading ndim_harm(n_end) x ndim_harm(n_end) block.
    _ : Array
        Dummy return array for numba guvectorize

    """
    # ret is ndim_harm(2 * n_end - 1) = (2 * n_end - 1)**2 wide; invert that to recover n_end.
    n_end = (int(sqrt(ret.shape[-1])) + 1) // 2
    # Column (n, m) = (0, 0): copy the initial values for every (nd, md) with nd < 2 * n_end - 1.
    for nd in prange(2 * n_end - 1):
        for md in prange(-nd, nd + 1):
            _set_coef(ret, nd, md, 0, 0, translational_coefficients_sectorial_init[idx_i(nd, md)])

    # Sectorial step for m = n >= 0: column (n + 1, m + 1) from column (n, m).
    # The "Nth term" comments label the terms of the underlying recurrence
    # (presumably Gumerov & Duraiswami's sectorial recurrence; TODO cite eq.).
    for m in range(2 * n_end - 2):
        n = abs(m)
        for nd in prange(2 * n_end - n - 2):
            for md in prange(-nd, nd + 1):
                tmp = -b(nd + 1, md - 1) * _get_coef(ret, nd + 1, md - 1, n, m)  # 3rd term
                # The (nd - 1) term exists only when (nd - 1, md - 1) is a valid index.
                if abs(md - 1) <= nd - 1:
                    tmp += b(nd, -md) * _get_coef(ret, nd - 1, md - 1, n, m)  # 4th term
                tmp /= b(n + 1, -m - 1)
                _set_coef(ret, nd, md, n + 1, m + 1, tmp)  # 2nd term

    # Same step for m = -n <= 0: column (n + 1, m - 1) from column (n, m).
    for m in range(2 * n_end - 2):
        m = -m
        n = abs(m)
        for nd in prange(2 * n_end - n - 2):
            for md in prange(-nd, nd + 1):
                # NOTE(review): unlike the loop above, the (nd - 1) term below is
                # not guarded; for md in {nd - 1, nd} (and nd == 0) b(nd, md)
                # evaluates to 0, so the out-of-range read is multiplied by zero.
                tmp = b(nd, md) * _get_coef(ret, nd - 1, md + 1, n, m)  # 1st term
                # NOTE(review): for md in [-nd, nd] this condition is always true,
                # so it never skips anything; confirm whether it was meant to
                # guard the (nd - 1) "1st term" instead (abs(md + 1) <= nd - 1).
                if abs(md + 1) <= nd + 1:
                    tmp -= b(nd + 1, -md - 1) * _get_coef(ret, nd + 1, md + 1, n, m)  # 2nd term
                tmp /= b(n + 1, m - 1)
                _set_coef(ret, nd, md, n + 1, m - 1, tmp)  # 3rd term

    # Symmetry: for every sectorial column (n, m) with n = |m|, fill the row
    # (n, m) as ret[(n, m), (nd, md)] = (-1)**(n + nd) * ret[(nd, -md), (n, -m)]
    # (swap=True writes to the transposed position).
    for m in prange(-2 * n_end + 2, 2 * n_end - 1):
        n = abs(m)
        for nd in prange(2 * n_end - 1):
            for md in prange(-nd, nd + 1):
                _set_coef(
                    ret,
                    nd,
                    md,
                    n,
                    m,
                    # minus_1_power_jit yields the +/-1 sign factor.
                    float32(minus_1_power_jit(n + nd)) * _get_coef(ret, nd, -md, n, -m),
                    swap=True,
                )

    # Remaining entries: recurse in degree for every pair of orders with
    # |m|, |md| < n_end. In this section the writes put order m on the row
    # axis and order md on the column axis (both loops span the same range).
    for m in prange(-n_end + 1, n_end):
        for md in prange(-n_end + 1, n_end):
            mabs, mdabs = abs(m), abs(md)
            mlarger = max(mabs, mdabs)
            n_iter = n_end - mlarger - 1

            # Two interleaved passes. (n2, m2) is the index whose degree is
            # advanced to n2 + 1, and (n1, m1) is the other one.
            # m1_is_md=True: swap=True, so (n2, m2) is the row index.
            # m1_is_md=False: no swap, so (n2, m2) is the column index.
            for m1_is_md, m1, m2 in ((True, md, m), (False, m, md)):
                # m and md should not be used below; only m1/m2 are meaningful here.
                # del m, md
                m1abs = abs(m1)
                m2abs = abs(m2)
                for i in range(n_iter):
                    n2 = i + m2abs  # n
                    for n1 in prange(m1abs + i + 1, 2 * n_end - mlarger - i - 2):  # nd
                        tmp = (
                            -a(n1, m1) * _get_coef(ret, n1 + 1, m1, n2, m2, swap=m1_is_md)  # 3rd
                            + a(n1 - 1, m1)
                            * _get_coef(ret, n1 - 1, m1, n2, m2, swap=m1_is_md)  # 4th
                        )
                        # At i == 0, n2 - 1 < |m2|, so a(n2 - 1, m2) would be 0
                        # and the (n2 - 1) index is out of range; skip the term.
                        if i > 0:
                            tmp += (
                                a(n2 - 1, m2)
                                * _get_coef(ret, n1, m1, n2 - 1, m2, swap=m1_is_md)  # 1st
                            )  # 1st
                        _set_coef(ret, n1, m1, n2 + 1, m2, tmp / a(n2, m2), swap=m1_is_md)  # 2nd


# Compiled variants of ``_translational_coefficients_all``, keyed by
# (target, numba complex type) with target "parallel" (CPU) or "cuda".
# A "cuda" entry is None when CUDA support is unavailable.
_translational_coefficients_all_impl = {}

# Build one kernel per precision. ``a``, ``b`` and ``dtype_f`` are rebound as
# module globals on every iteration and read by
# ``_translational_coefficients_all``. Numba treats global values as
# compile-time constants, and ``guvectorize`` with an explicit signature
# compiles immediately. Each kernel therefore captures the a/b/dtype_f of its
# own iteration (this depends on the statement order below). The
# ``noqa: B023`` markers silence the loop-variable-capture lint for this
# deliberate pattern.
for dtype_f, dtype_c in ((float32, complex64), (float64, complex128)):

    @jit()
    def a(n: int, m: int, /) -> float:
        """Recurrence coefficient a_n^m in precision ``dtype_f``.

        sqrt((n + |m| + 1) * (n - |m| + 1) / ((2n + 1) * (2n + 3))) for
        |m| <= n, and 0 otherwise.
        """
        m_abs = abs(m)
        if m_abs > n:
            return dtype_f(0)  # noqa: B023
        return sqrt(dtype_f((n + m_abs + 1) * (n - m_abs + 1)) / dtype_f((2 * n + 1) * (2 * n + 3)))  # noqa: B023

    @jit()
    def b(n: int, m: int, /) -> float:
        """Recurrence coefficient b_n^m in precision ``dtype_f``.

        +sqrt((n - m - 1) * (n - m) / ((2n - 1) * (2n + 1))) for 0 <= m <= n,
        the negated value for -n <= m < 0, and 0 for |m| > n.
        """
        m_abs = abs(m)
        if m_abs > n:
            return dtype_f(0)  # noqa: B023
        tmp = sqrt(dtype_f((n - m - 1) * (n - m)) / dtype_f((2 * n - 1) * (2 * n + 1)))  # noqa: B023
        if m >= 0:
            return tmp
        else:
            return -tmp

    # Signature: initial coefficients (n,), working matrix (n, n) that the
    # kernel fills in place, and a dummy scalar output.
    _numba_args: tuple[list[Any], str] = (
        [(dtype_c[:], dtype_c[:, :], dtype_c)],
        "(n),(n,n)->()",
    )
    _translational_coefficients_all_impl[("parallel", dtype_c)] = numba.guvectorize(
        *_numba_args, target="parallel", cache=True
    )(_translational_coefficients_all)

    try:
        _translational_coefficients_all_impl[("cuda", dtype_c)] = numba.guvectorize(
            *_numba_args, target="cuda"
        )(_translational_coefficients_all)
    except CudaSupportError:
        # No usable CUDA installation: record that no GPU kernel exists.
        _translational_coefficients_all_impl[("cuda", dtype_c)] = None


def translational_coefficients_all(
    *,
    n_end: int,
    translational_coefficients_sectorial_init: Array,
) -> Array:
    """Translational coefficients (E|F)^{m',m}_{n',n}

    Runs the compiled recurrence kernel on every batch element. It then keeps
    only the block of degrees below ``n_end``.

    Parameters
    ----------
    n_end : int
        Exclusive upper bound of the degrees n and n' of the returned
        coefficients.
    translational_coefficients_sectorial_init : Array
        Initial sectorial translational coefficients (E|F)^{m',0}_{n', 0}
        of shape (..., ndim_harm(2 * n_end - 1),), as returned by
        ``translational_coefficients_sectorial_init``. Any number of leading
        batch dimensions is accepted.

    Returns
    -------
    Array
        Translational coefficients indexed as [idx(n', m'), idx(n, m)], of
        shape (..., ndim_harm(n_end), ndim_harm(n_end)).

    Raises
    ------
    ValueError
        If the last axis of the initial coefficients does not have length
        ``ndim_harm(2 * n_end - 1)``.
    RuntimeError
        If the input lives on a CUDA device but no CUDA kernel could be built.
    """
    init = translational_coefficients_sectorial_init
    xp = array_namespace(init)
    dtype = init.dtype
    device = init.device
    batch_shape = init.shape[:-1]
    n_batch = prod(batch_shape)
    # The recurrences need degrees up to 2 * (n_end - 1). The kernel recovers
    # n_end from the size of this working matrix, so the input must match it.
    n_work = ndim_harm(2 * n_end - 1)
    if init.shape[-1] != n_work:
        raise ValueError(
            "translational_coefficients_sectorial_init must have a last axis of length "
            f"ndim_harm(2 * n_end - 1) = {n_work}, got {init.shape[-1]}"
        )

    target = "cuda" if "cuda" in str(device) else "parallel"
    kernel = _translational_coefficients_all_impl[
        (target, complex128 if dtype == xp.complex128 else complex64)
    ]
    if kernel is None:
        raise RuntimeError(f"CUDA support is not available; cannot compute on device {device}")

    # Flatten all batch dimensions into one axis so the gufunc loop dimensions
    # of the input and of ``ret`` agree. Without this, a batch shape such as
    # (3, 4) would be broadcast against (12,) and fail.
    init_flat = xp.reshape(init, (n_batch, n_work))
    ret = xp.zeros((n_batch, n_work, n_work), dtype=dtype, device=device)
    # The kernel fills ``ret`` in place; the gufunc's scalar output is a dummy.
    kernel(init_flat, ret)

    n_harm = ndim_harm(n_end)
    ret = xp.asarray(ret, dtype=dtype, device=device)[..., :n_harm, :n_harm]
    return xp.reshape(ret, (*batch_shape, *ret.shape[-2:]))


def translational_coefficients(
    kr: Array, theta: Array, phi: Array, *, same: bool, n_end: int
) -> Array:
    r"""Translational coefficients (E|F)^{m',m}_{n',n}.

    The elementary solutions are

    .. math::

        Y_n^m (\theta, \phi) &:= (-1)^m \sqrt{\frac{(2n+1)(n-\left|m\right|)!}{4 \pi (n+\left|m\right|)!}}
        P_n^{\left|m\right|} (\cos \theta) e^{i m \phi} \\
        R_n^m (kr, \theta, \phi) &:= j_n(kr) Y_n^m (\theta, \phi) \\
        S_n^m (kr, \theta, \phi) &:= h_n^{(1)}(kr) Y_n^m (\theta, \phi)

    Parameters
    ----------
    kr : Array
        k * r of shape (...,)
    theta : Array
        polar angle of shape (...,)
    phi : Array
        azimuthal angle of shape (...,)
    same : bool
        If True, return (R|R) = (S|S).
        If False, return (S|R).
    n_end : int
        Exclusive upper bound of the degrees n and n' (degrees 0 .. n_end - 1).

    Returns
    -------
    Array
        Translational coefficients indexed as [idx(n', m'), idx(n, m)], of
        shape (..., ndim_harm(n_end), ndim_harm(n_end)), i.e.
        (..., n_end**2, n_end**2).
    """
    translational_coefficients_sectorial_init_ = translational_coefficients_sectorial_init(
        kr, theta, phi, same, n_end
    )
    return translational_coefficients_all(
        n_end=n_end,
        translational_coefficients_sectorial_init=translational_coefficients_sectorial_init_,
    )