Source code for ultrasphere_harmonics._core._harmonics

from collections.abc import Mapping
from typing import Literal, overload

from array_api._2024_12 import Array
from ultrasphere import BranchingType, SphericalCoordinates, get_child
from ultrasphere._coordinates import TCartesian, TSpherical

from ._concat import concat_harmonics
from ._eigenfunction import Phase, type_a, type_b, type_bdash, type_c
from ._expand_dim import expand_dims_harmonics
from ._flatten import flatten_harmonics


def _harmonics(
    c: SphericalCoordinates[TSpherical, TCartesian],
    spherical: Mapping[TSpherical, Array],
    n_end: int,
    *,
    phase: Phase,
    include_negative_m: bool = True,
    index_with_surrogate_quantum_number: bool = False,
) -> Mapping[TSpherical, Array] | Array:
    """
    Calculate the spherical harmonics.

    Parameters
    ----------
    c : SphericalCoordinates[TSpherical, TCartesian]
        The spherical coordinates.
    spherical : Mapping[TSpherical, Array]
        The spherical coordinates.
    n_end : int
        The maximum degree of the harmonic.
    phase : Phase
        Adjust phase (±) of the spherical harmonics, mainly to match conventions.
        See `Phase` for details.
    include_negative_m : bool, optional
        Whether to include negative m values, by default True
        If True, the m values are [0, 1, ..., n_end-1, -n_end+1, ..., -1],
        and starts from 0, not -n_end+1.
    index_with_surrogate_quantum_number : bool, optional
        Whether to index with surrogate quantum number, by default False

    Returns
    -------
    Array
        The spherical harmonics.

    """
    result = {}
    for node in c.s_nodes:
        value = spherical[node]
        if node == "r":
            continue
        if node not in c.s_nodes:
            raise ValueError(f"Key {node} is not in c.s_nodes {c.s_nodes}.")
        if c.branching_types[node] == BranchingType.A:
            result[node] = type_a(
                value,
                n_end=n_end,
                phase=phase,
                include_negative_m=include_negative_m,
            )
        elif c.branching_types[node] == BranchingType.B:
            result[node] = type_b(
                value,
                n_end=n_end,
                s_beta=c.S[get_child(c.G, node, "sin")],
                index_with_surrogate_quantum_number=index_with_surrogate_quantum_number,
                is_beta_type_a_and_include_negative_m=include_negative_m
                and c.branching_types[get_child(c.G, node, "sin")] == BranchingType.A,
            )
        elif c.branching_types[node] == BranchingType.BP:
            result[node] = type_bdash(
                value,
                n_end=n_end,
                s_alpha=c.S[get_child(c.G, node, "cos")],
                index_with_surrogate_quantum_number=index_with_surrogate_quantum_number,
                is_alpha_type_a_and_include_negative_m=include_negative_m
                and c.branching_types[get_child(c.G, node, "cos")] == BranchingType.A,
            )
        elif c.branching_types[node] == BranchingType.C:
            result[node] = type_c(
                value,
                n_end=n_end,
                s_alpha=c.S[get_child(c.G, node, "cos")],
                s_beta=c.S[get_child(c.G, node, "sin")],
                index_with_surrogate_quantum_number=index_with_surrogate_quantum_number,
                is_alpha_type_a_and_include_negative_m=include_negative_m
                and c.branching_types[get_child(c.G, node, "cos")] == BranchingType.A,
                is_beta_type_a_and_include_negative_m=include_negative_m
                and c.branching_types[get_child(c.G, node, "sin")] == BranchingType.A,
            )
        else:
            raise ValueError(f"Invalid branching type {c.branching_types[node]}.")
    return result


@overload
def harmonics(
    c: SphericalCoordinates[TSpherical, TCartesian],
    spherical: Mapping[TSpherical, Array],
    /,
    *,
    n_end: int,
    phase: Phase,
    include_negative_m: bool = True,
    index_with_surrogate_quantum_number: bool = False,
    expand_dims: bool = True,
    flatten: bool | None = None,
    concat: Literal[True] = ...,
) -> Array: ...


@overload
def harmonics(
    c: SphericalCoordinates[TSpherical, TCartesian],
    spherical: Mapping[TSpherical, Array],
    /,
    *,
    n_end: int,
    phase: Phase,
    include_negative_m: bool = True,
    index_with_surrogate_quantum_number: bool = False,
    expand_dims: bool = True,
    flatten: bool | None = None,
    concat: Literal[False] = ...,
) -> Mapping[TSpherical, Array]: ...


[docs] def harmonics( c: SphericalCoordinates[TSpherical, TCartesian], spherical: Mapping[TSpherical, Array], /, *, n_end: int, phase: Phase, include_negative_m: bool = True, index_with_surrogate_quantum_number: bool = False, expand_dims: bool = True, flatten: bool | None = None, concat: bool = True, ) -> Mapping[TSpherical, Array] | Array: """ Calculate the spherical harmonics. Parameters ---------- c : SphericalCoordinates[TSpherical, TCartesian] The spherical coordinates. spherical : Mapping[TSpherical, Array] The spherical coordinates. n_end : int The maximum degree of the harmonic. phase : Phase Adjust phase (±) of the spherical harmonics, mainly to match conventions. See `Phase` for details. include_negative_m : bool, optional Whether to include negative m values, by default True If True, the m values are [0, 1, ..., n_end-1, -n_end+1, ..., -1], and starts from 0, not -n_end+1. index_with_surrogate_quantum_number : bool, optional Whether to index with surrogate quantum number, by default False expand_dims : bool, optional Whether to expand dimensions so that all values of the result dictionary are commomly indexed by the same s_nodes, by default False For example, if spherical coordinates, - if True, the result will be indexed {"phi": [m], "theta": [m, n]} - if False, the result will be indexed {"phi": [m, n], "theta": [m, n]} Note that the values will not be repeated therefore the computational cost will be the same flatten : bool, optional Whether to flatten the result, by default None If None, True iff concat is True. concat : bool, optional Whether to concatenate the results, by default True Returns ------- Array The spherical harmonics. Example ------- Basic usage ^^^^^^^^^^^ >>> from array_api_compat import numpy as np >>> from ultrasphere import create_spherical >>> c = create_spherical() >>> harm = harmonics( # flattened output ... c, ... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)}, ... n_end=2, ... phase=0 ... ) >>> np.round(harm, 2) array([0.28+0.j , 0.43+0.j , 0.09+0.14j, 0.09-0.14j]) Unflattened output ^^^^^^^^^^^^^^^^^^ >>> harm = harmonics( ... c, ... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)}, ... n_end=2, ... phase=0, ... flatten=False, ... ) >>> np.round(harm, 2) array([[0.28+0.j , 0. +0.j , 0. +0.j ], [0.43+0.j , 0.09+0.14j, 0.09-0.14j]]) Unflattened mapping output ^^^^^^^^^^^^^^^^^^^^^^^^^^ >>> harm = harmonics( ... c, ... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)}, ... n_end=2, ... phase=0, ... concat=False ... ) >>> {k: np.round(harm[k], 2) for k in c.s_nodes} {'theta': array([[0.71, 0. , 0. ], [1.07, 0.42, 0.42]]), 'phi': array([[0.4 +0.j , 0.22+0.34j, 0.22-0.34j]])} Using different phase convention to match scipy.special.sph_harm ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Call `harmonics` with `phase=3` (negative legendre + Condon-Shortley phase) >>> harm = harmonics( ... c, ... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)}, ... n_end=2, ... phase=3 # negative legendre + Condon-Shortley phase ... ) >>> np.round(harm, 2) array([ 0.28+0.j , 0.43+0.j , -0.09-0.14j, 0.09-0.14j]) Call `scipy.special.sph_harm_y_all` >>> from scipy.special import sph_harm_y_all >>> harm_scipy = sph_harm_y_all(1, 1, np.asarray(0.5), np.asarray(1.0)) >>> np.round(harm_scipy, 2) array([[ 0.28+0.j , 0. +0.j , 0. +0.j ], [ 0.43+0.j , -0.09-0.14j, 0.09-0.14j]]) Flatten the scipy output to compare with `harmonics` >>> from ultrasphere_harmonics import flatten_harmonics >>> harm_scipy_flatten = flatten_harmonics(c, harm_scipy) >>> np.round(harm_scipy_flatten, 2) array([ 0.28+0.j , 0.43+0.j , -0.09-0.14j, 0.09-0.14j]) Check if they are close >>> import array_api_extra as xpx >>> np.all(xpx.isclose(harm, harm_scipy_flatten)) np.True_ """ if flatten is None: flatten = concat if index_with_surrogate_quantum_number and expand_dims: raise ValueError( "expand_dims must be False if index_with_surrogate_quantum_number is True." ) if concat and not expand_dims: raise ValueError("expand_dims must be True if concat is True.") if flatten and not expand_dims: raise ValueError("expand_dims must be True if flatten is True.") result = _harmonics( c, spherical, n_end, phase=phase, include_negative_m=include_negative_m, index_with_surrogate_quantum_number=index_with_surrogate_quantum_number, ) if expand_dims: result = expand_dims_harmonics(c, result) if concat: result = concat_harmonics(c, result) if flatten: if concat: result = flatten_harmonics(c, result) else: result = {k: flatten_harmonics(c, v) for k, v in result.items()} return result