from collections.abc import Mapping
from typing import Any, Literal, overload
import array_api_extra as xpx
from array_api._2024_12 import Array, ArrayNamespaceFull
from array_api_compat import array_namespace
from array_api_negative_index import to_symmetric
from shift_nth_row_n_steps._torch_like import create_slice
from ultrasphere import (
BranchingType,
SphericalCoordinates,
get_child,
)
from ._assume import assume_n_end_and_include_negative_m_from_harmonics
def _index_array_harmonics[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
node: TSpherical,
*,
n_end: int,
xp: ArrayNamespaceFull,
expand_dims: bool = True,
include_negative_m: bool = True,
dtype: Any = None,
device: Any = None,
) -> Array:
"""
The index of the eigenfunction corresponding to the node.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
node : TSpherical
The node of the spherical coordinates.
n_end : int
The maximum degree of the harmonic.
expand_dims : bool, optional
Whether to expand dimensions, by default True
include_negative_m : bool, optional
Whether to include negative m values, by default True
xp : ArrayNamespaceFull
The array namespace.
dtype : Any, optional
The dtype, by default None
device : Any, optional
The device, by default None
Returns
-------
Array
The index.
"""
branching_type = c.branching_types[node]
if branching_type == BranchingType.A and include_negative_m:
result = to_symmetric(
xp.arange(0, n_end, dtype=dtype, device=device), asymmetric=True
)
elif (
branching_type == BranchingType.B
or branching_type == BranchingType.BP
or (branching_type == BranchingType.A and not include_negative_m)
):
result = xp.arange(0, n_end, dtype=dtype, device=device)
elif branching_type == BranchingType.C:
# result = xp.arange(0, (n_end + 1) // 2)
result = xp.arange(0, n_end, dtype=dtype, device=device)
if expand_dims:
idx = c.s_nodes.index(node)
result = result[create_slice(c.s_ndim, [(idx, slice(None))], default=None)]
return result
# Overload: with ``as_array=False`` (which requires ``mask=False``) a mapping
# from each node to its own index array is returned.
@overload
def _index_array_harmonics_all[TSpherical, TCartesian](
    c: SphericalCoordinates[TSpherical, TCartesian],
    /,
    *,
    n_end: int,
    xp: ArrayNamespaceFull,
    include_negative_m: bool = ...,
    expand_dims: bool = ...,
    as_array: Literal[False],
    mask: Literal[False] = ...,
    dtype: Any = ...,
    device: Any = ...,
) -> Mapping[TSpherical, Array]: ...
# Overload: with ``as_array=True`` (which requires ``expand_dims=True``) the
# per-node index arrays are broadcast and stacked into a single array.
@overload
def _index_array_harmonics_all[TSpherical, TCartesian](
    c: SphericalCoordinates[TSpherical, TCartesian],
    /,
    *,
    n_end: int,
    xp: ArrayNamespaceFull,
    include_negative_m: bool = ...,
    expand_dims: Literal[True] = ...,
    as_array: Literal[True],
    mask: bool = ...,
    dtype: Any = ...,
    device: Any = ...,
) -> Array: ...
def _index_array_harmonics_all[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
/,
*,
n_end: int,
xp: ArrayNamespaceFull,
include_negative_m: bool = True,
expand_dims: bool = True,
as_array: bool,
mask: bool = False,
dtype: Any = None,
device: Any = None,
) -> Array | Mapping[TSpherical, Array]:
"""
The all indices of the eigenfunction corresponding to the spherical coordinates.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
n_end : int
The maximum degree of the harmonic.
include_negative_m : bool, optional
Whether to include negative m values, by default True
expand_dims : bool, optional
Whether to expand dimensions, by default True
Must be True if as_array is True.
as_array : bool, optional
Whether to return as an array, by default False
mask : bool, optional
Whether to fill invalid quantum numbers with NaN, by default False
Must be False if as_array is False.
xp : ArrayNamespaceFull
The array namespace.
dtype : Any, optional
The dtype, by default None
device : Any, optional
The device, by default None
Returns
-------
Array | Mapping[TSpherical, Array]
If as_array is True, the indices of shape
[c.s_ndim,
len(index_array_harmonics(c, node1)),
...,
len(index_array_harmonics(c, node(c.s_ndim)))].
If as_array is False, the dictionary of indices.
Notes
-----
To check the indices where all quantum numbers match,
`(numbers1 == numbers2).all(axis=0)`
can be used.
Raises
------
ValueError
If expand_dims is False and as_array is True.
If mask is True and as_array is False.
"""
if not expand_dims and as_array:
raise ValueError("expand_dims must be True if as_array is True.")
if mask and not as_array:
raise ValueError("mask must be False if as_array is False.")
index_arrays = {
node: _index_array_harmonics(
c,
node,
xp=xp,
n_end=n_end,
expand_dims=expand_dims,
include_negative_m=include_negative_m,
dtype=dtype,
device=device,
)
for node in c.s_nodes
}
if as_array:
result = xp.stack(
xp.broadcast_arrays(*[index_arrays[node] for node in c.s_nodes]),
axis=0,
)
if mask:
result[
:,
~flatten_mask_harmonics(
c, n_end=n_end, xp=xp, include_negative_m=include_negative_m
),
] = xp.nan
return result
return index_arrays
def flatten_mask_harmonics[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
/,
*,
n_end: int,
xp: ArrayNamespaceFull,
include_negative_m: bool = True,
device: Any = None,
) -> Array:
"""
Create a mask representing the valid combinations of the quantum numbers.
Can be used to flatten the harmonics.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
n_end : int
The maximum degree of the harmonic.
include_negative_m : bool, optional
Whether to include negative m values, by default True
nodes : Iterable[TSpherical] | None, optional
The nodes to consider, by default None
If None, all nodes are considered.
xp : ArrayNamespaceFull
The array namespace.
device : Any, optional
The device, by default None
Returns
-------
Array
The mask.
Example
-------
For spherical coordinates, |m| <= n are the valid combinations.
>>> from array_api_compat import numpy as np
>>> from ultrasphere import create_spherical
>>> c = create_spherical()
>>> flatten_mask_harmonics(c, n_end=2, xp=np)
array([[ True, False, False],
[ True, True, True]])
"""
index_arrays: Mapping[TSpherical, Array] = _index_array_harmonics_all(
c,
n_end=n_end,
include_negative_m=include_negative_m,
as_array=False,
expand_dims=True,
xp=xp,
device=device,
)
mask = xp.ones((1,) * c.s_ndim, dtype=bool, device=device)
for node, branching_type in c.branching_types.items():
if branching_type == BranchingType.B:
mask = mask & (
xp.abs(index_arrays[get_child(c.G, node, "sin")]) <= index_arrays[node]
)
if branching_type == BranchingType.BP:
mask = mask & (
xp.abs(index_arrays[get_child(c.G, node, "cos")]) <= index_arrays[node]
)
if branching_type == BranchingType.C:
value = (
index_arrays[node]
- xp.abs(index_arrays[get_child(c.G, node, "sin")])
- xp.abs(index_arrays[get_child(c.G, node, "cos")])
)
mask = mask & (value % 2 == 0) & (value >= 0)
shape = xpx.broadcast_shapes(
*[index_array.shape for index_array in index_arrays.values()]
)
mask = xp.broadcast_to(mask, shape)
return mask
[docs]
def flatten_harmonics[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
harmonics: Array,
n_end: int | None = None,
include_negative_m: bool | None = None,
axis_end: int = -1,
) -> Array:
"""
Flatten the harmonics.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
harmonics : Array
The (unflattend) harmonics.
n_end : int | None, optional
The maximum degree of the harmonic, by default None
If None, assume from the shape of harmonics.
include_negative_m : bool | None, optional
Whether to include negative m values, by default None
If None, assume from the shape of harmonics.
axis_end : int, optional
The axis to flatten, by default -1
Must be negative.
Returns
-------
Array
The flattened harmonics of shape (..., n_harmonics).
Example
-------
>>> from array_api_compat import numpy as np
>>> from ultrasphere import create_spherical
>>> from ultrasphere_harmonics import harmonics
>>> c = create_spherical()
>>> harm = harmonics(
... c,
... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)},
... n_end=2,
... phase=0,
... flatten=False,
... )
>>> np.round(harm, 2)
array([[0.28+0.j , 0. +0.j , 0. +0.j ],
[0.43+0.j , 0.09+0.14j, 0.09-0.14j]])
>>> harm_flat = flatten_harmonics(c, harm)
>>> np.round(harm_flat, 2)
array([0.28+0.j , 0.43+0.j , 0.09+0.14j, 0.09-0.14j])
"""
if axis_end >= 0:
raise ValueError("axis_end must be negative.")
xp = array_namespace(harmonics)
if n_end is None or include_negative_m is None:
n_end, include_negative_m = assume_n_end_and_include_negative_m_from_harmonics(
c,
harmonics.shape if axis_end == -1 else harmonics.shape[: axis_end + 1],
flatten=False,
)
mask = flatten_mask_harmonics(
c,
n_end=n_end,
xp=xp,
include_negative_m=include_negative_m,
device=harmonics.device,
)
shape = xpx.broadcast_shapes(harmonics.shape, mask.shape + (1,) * (-axis_end - 1))
harmonics = xp.broadcast_to(harmonics, shape)
return harmonics[(..., mask) + (slice(None),) * (-axis_end - 1)]
def unflatten_harmonics[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
harmonics: Array,
*,
include_negative_m: bool = True,
) -> Array:
"""
Unflatten the harmonics.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
harmonics : Array
The flattened harmonics.
include_negative_m : bool, optional
Whether to include negative m values, by default True
Returns
-------
Array
The unflattened harmonics of shape (..., n_1, n_2, ..., n_(c.s_ndim)).
Example
-------
>>> from array_api_compat import numpy as np
>>> from ultrasphere import create_spherical
>>> from ultrasphere_harmonics import harmonics
>>> c = create_spherical()
>>> harm_flat = harmonics(
... c,
... {"theta": np.asarray(0.5), "phi": np.asarray(1.0)},
... n_end=2,
... phase=0,
... )
>>> np.round(harm_flat, 2)
array([0.28+0.j , 0.43+0.j , 0.09+0.14j, 0.09-0.14j])
>>> harm = unflatten_harmonics(c, harm_flat)
>>> np.round(harm, 2)
array([[0.28+0.j , 0. +0.j , 0. +0.j ],
[0.43+0.j , 0.09+0.14j, 0.09-0.14j]])
"""
xp = array_namespace(harmonics)
n_end, _ = assume_n_end_and_include_negative_m_from_harmonics(
c, harmonics, flatten=True
)
mask = flatten_mask_harmonics(
c,
n_end=n_end,
xp=xp,
include_negative_m=include_negative_m,
device=harmonics.device,
)
shape = (*harmonics.shape[:-1], *mask.shape)
result = xp.zeros(shape, dtype=harmonics.dtype, device=harmonics.device)
result[..., mask] = harmonics
return result
[docs]
def index_array_harmonics[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
node: TSpherical,
/,
*,
n_end: int,
xp: ArrayNamespaceFull,
expand_dims: bool = True,
include_negative_m: bool = True,
flatten: bool = False,
dtype: Any = None,
device: Any = None,
) -> Array:
"""
The index of the eigenfunction corresponding to the node.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
node : TSpherical
The node of the spherical coordinates.
n_end : int
The maximum degree of the harmonic.
expand_dims : bool, optional
Whether to expand dimensions, by default True
include_negative_m : bool, optional
Whether to include negative m values, by default True
If None, True iff concat is True.
flatten : bool, optional
Whether to flatten the result, by default False
xp : ArrayNamespaceFull
The array namespace.
dtype : Any, optional
The dtype, by default None
device : Any, optional
The device, by default None
Returns
-------
Array
The index.
Example
-------
>>> from array_api_compat import numpy as np
>>> from ultrasphere import create_spherical
>>> c = create_spherical()
>>> index_array_harmonics(
... c,
... "theta",
... n_end=3,
... xp=np,
... )
array([[0],
[1],
[2]])
>>> index_array_harmonics(
... c,
... "phi",
... n_end=3,
... xp=np,
... )
array([[ 0, 1, 2, -2, -1]])
"""
if flatten and not expand_dims:
raise ValueError("expand_dims must be True if flatten is True.")
index_array = _index_array_harmonics(
c,
node,
n_end=n_end,
xp=xp,
expand_dims=expand_dims,
include_negative_m=include_negative_m,
dtype=dtype,
device=device,
)
if flatten:
return flatten_harmonics(
c, index_array, n_end=n_end, include_negative_m=include_negative_m
)
return index_array
# Overload: with ``as_array=False`` (which requires ``mask=False``) a mapping
# from each node to its index array is returned.
@overload
def index_array_harmonics_all[TSpherical, TCartesian](
    c: SphericalCoordinates[TSpherical, TCartesian],
    *,
    n_end: int,
    xp: ArrayNamespaceFull,
    include_negative_m: bool = ...,
    expand_dims: bool = ...,
    as_array: Literal[False],
    mask: Literal[False] = ...,
    flatten: bool | None = ...,
    dtype: Any = ...,
    device: Any = ...,
) -> Mapping[TSpherical, Array]: ...
# Overload: with ``as_array=True`` (which requires ``expand_dims=True``) a
# single stacked array with one leading entry per node is returned.
@overload
def index_array_harmonics_all[TSpherical, TCartesian](
    c: SphericalCoordinates[TSpherical, TCartesian],
    *,
    n_end: int,
    xp: ArrayNamespaceFull,
    include_negative_m: bool = ...,
    expand_dims: Literal[True] = ...,
    as_array: Literal[True],
    mask: bool = ...,
    flatten: bool | None = ...,
    dtype: Any = ...,
    device: Any = ...,
) -> Array: ...
[docs]
def index_array_harmonics_all[TSpherical, TCartesian](
c: SphericalCoordinates[TSpherical, TCartesian],
*,
n_end: int,
xp: ArrayNamespaceFull,
include_negative_m: bool = True,
expand_dims: bool = True,
as_array: bool,
mask: bool = False,
flatten: bool | None = None,
dtype: Any = None,
device: Any = None,
) -> Array | Mapping[TSpherical, Array]:
"""
The all indices of the eigenfunction corresponding to the spherical coordinates.
Parameters
----------
c : SphericalCoordinates[TSpherical, TCartesian]
The spherical coordinates.
n_end : int
The maximum degree of the harmonic.
include_negative_m : bool, optional
Whether to include negative m values, by default True
expand_dims : bool, optional
Whether to expand dimensions, by default True
Must be True if as_array is True.
as_array : bool, optional
Whether to return as an array, by default False
mask : bool, optional
Whether to fill invalid quantum numbers with NaN, by default False
Must be False if as_array is False.
flatten : bool, optional
Whether to flatten the result, by default None
If None, True iff as_array is True.
xp : ArrayNamespaceFull
The array namespace.
dtype : Any, optional
The dtype, by default None
device : Any, optional
The device, by default None
Returns
-------
Array | Mapping[TSpherical, Array]
If as_array is True, the indices of shape
`[c.s_ndim,
len(index_array_harmonics(c, node1)),
...,
len(index_array_harmonics(c, node(c.s_ndim)))]`.
If as_array is False, the dictionary of indices.
Notes
-----
To check the indices where all quantum numbers match,
`(numbers1 == numbers2).all(axis=0)`
can be used.
Raises
------
ValueError
If expand_dims is False and as_array is True.
If mask is True and as_array is False.
Example
-------
>>> from array_api_compat import numpy as np
>>> from ultrasphere import create_spherical
>>> c = create_spherical()
>>> index_array_harmonics_all(
... c,
... n_end=3,
... xp=np,
... as_array=True,
... )
array([[ 0, 1, 1, 1, 2, 2, 2, 2, 2],
[ 0, 0, 1, -1, 0, 1, 2, -2, -1]])
>>> index_array_harmonics_all(
... c,
... n_end=3,
... xp=np,
... as_array=False,
... )
{'theta': array([[0],
[1],
[2]]), 'phi': array([[ 0, 1, 2, -2, -1]])}
"""
if flatten is None:
flatten = as_array
if flatten and not expand_dims:
raise ValueError("expand_dims must be True if flatten is True.")
index_arrays = _index_array_harmonics_all( # type: ignore[call-overload]
c,
n_end=n_end,
xp=xp,
include_negative_m=include_negative_m,
as_array=as_array,
expand_dims=expand_dims,
mask=mask,
dtype=dtype,
device=device,
)
if flatten:
if as_array:
return flatten_harmonics(c, index_arrays)
return {
node: flatten_harmonics(c, index_array)
for node, index_array in index_arrays.items()
}
return index_arrays