# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.
from mdapy import _spline
import numpy as np
from typing import Union, Tuple, List, Optional
[docs]
class Spline:
"""Cubic spline interpolation on a strictly-increasing grid.
Constructs a piecewise-cubic :math:`s(x)` that is :math:`C^2` over the whole
range, reproduces the sample points :math:`(x_i, y_i)` exactly, and satisfies
the chosen boundary condition at the two endpoints. The grid need not be
uniform — if it is, prefer the internal ``UniformCubicSpline`` (used by the
EAM code, not exposed here) which is O(1) per lookup.
Parameters
----------
x : array_like
1-D array of x-coordinates. Must be strictly increasing and contain at
least two points.
y : array_like
1-D array of y-coordinates, same length as ``x``.
bc_type : {"not-a-knot", "natural", "clamped"}, default "not-a-knot"
Boundary condition at the two endpoints:
- ``"not-a-knot"`` — the third derivative is continuous at ``x[1]`` and
``x[n-2]`` (equivalently, the first two and last two cubic pieces are
each a single polynomial). Same default as
``scipy.interpolate.CubicSpline``. Best for general data when the
endpoint slopes are unknown.
- ``"natural"`` — :math:`s''(x_0) = s''(x_{n-1}) = 0`. Produces a
minimum-curvature interpolant that flattens out at the ends.
- ``"clamped"`` — :math:`s'(x_0) = \\texttt{dy0}`,
:math:`s'(x_{n-1}) = \\texttt{dyn}`. If ``dy0`` and ``dyn`` are not
given, they are estimated by fitting a quadratic through the first
(last) three points and taking its analytic derivative at the
endpoint.
dy0, dyn : float, optional
Endpoint first derivatives, only used when ``bc_type="clamped"``. Both
must be provided together; if either is ``None`` the three-point
estimates are used.
Notes
-----
Evaluation is O(log n) per point via binary search. Batch evaluation is
OpenMP-parallelised.
Out-of-range queries raise ``IndexError`` for scalar calls and return
``NaN`` element-wise for array calls. There is deliberately no silent
extrapolation — cubic extrapolation past the last knot can swing wildly
on smooth-looking data (see the EAM rho-clamping incident for a live
example).
Examples
--------
>>> import numpy as np
>>> x = np.linspace(0, 2 * np.pi, 13)
>>> y = np.sin(x)
>>> sp = Spline(x, y) # default: not-a-knot
>>> abs(sp.evaluate(np.pi / 4) - np.sin(np.pi / 4)) < 1e-4
True
>>> sp.derivative(0.0) # should be ~cos(0) = 1
1.0000... # doctest: +SKIP
A clamped spline with user-supplied endpoint slopes — useful when you know
the analytic derivative at the ends (here we know :math:`\\cos(0) = 1` and
:math:`\\cos(2\\pi) = 1`):
>>> sp_c = Spline(x, y, bc_type="clamped", dy0=1.0, dyn=1.0)
The natural spline, by contrast, forces :math:`s'' = 0` at the ends, which
is appropriate when you expect the data to flatten beyond the sample range.
"""
_BC_MAP = {
"not-a-knot": _spline.BCType.NotAKnot,
"natural": _spline.BCType.Natural,
"clamped": _spline.BCType.Clamped,
}
def __init__(
self,
x: Union[List, Tuple, np.ndarray],
y: Union[List, Tuple, np.ndarray],
bc_type: str = "not-a-knot",
dy0: Optional[float] = None,
dyn: Optional[float] = None,
):
self.x, self.y = self._validate(x, y)
self.bc_type = bc_type
if bc_type not in self._BC_MAP:
raise ValueError(
f"Unknown bc_type {bc_type!r}. "
f"Expected one of {list(self._BC_MAP)}."
)
if bc_type == "clamped" and (dy0 is not None or dyn is not None):
if dy0 is None or dyn is None:
raise ValueError(
"For clamped with explicit derivatives both dy0 and dyn "
"must be given."
)
self._sp = _spline.CubicSpline(self.x, self.y, float(dy0), float(dyn))
else:
self._sp = _spline.CubicSpline(self.x, self.y, self._BC_MAP[bc_type])
# ------------------------------------------------------------------
# Evaluation helpers
# ------------------------------------------------------------------
[docs]
def evaluate(
self, x: Union[float, int, List, Tuple, np.ndarray]
) -> Union[float, np.ndarray]:
"""Evaluate :math:`s(x)` at scalar or array ``x``.
Array inputs return an ``np.ndarray`` of the same length; entries
outside the interpolation range become ``NaN``. Scalar inputs raise
``IndexError`` if out of range.
"""
return self._call(self._sp.evaluate, x, "value")
[docs]
def derivative(
self, x: Union[float, int, List, Tuple, np.ndarray]
) -> Union[float, np.ndarray]:
"""Evaluate :math:`s'(x)` at scalar or array ``x``.
The derivative is computed analytically from the stored cubic
coefficients, not by finite differencing.
"""
return self._call(self._sp.derivative, x, "derivative")
[docs]
def second_derivative(
self, x: Union[float, int, List, Tuple, np.ndarray]
) -> Union[float, np.ndarray]:
"""Evaluate :math:`s''(x)` at scalar or array ``x``.
:math:`s''` is piecewise-linear between the knots (a property of
cubic splines), so this is exact up to floating-point rounding.
"""
return self._call(self._sp.second_derivative, x, "second derivative")
# Convenience: ``sp(x)`` and ``sp.evaluate(x)`` do the same thing.
__call__ = evaluate
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _call(self, backend, x, kind):
if isinstance(x, (int, float, np.integer, np.floating)):
xf = float(x)
if xf < self.x[0] or xf > self.x[-1]:
raise IndexError(
f"Cannot evaluate {kind} at x={xf}: outside interpolation "
f"range [{self.x[0]}, {self.x[-1]}]."
)
return backend(xf)
if isinstance(x, np.ndarray):
x_arr = x if x.dtype == np.float64 else x.astype(np.float64)
elif isinstance(x, (list, tuple)):
x_arr = np.asarray(x, dtype=np.float64)
else:
raise TypeError(
f"Input type {type(x)} not supported. "
"Expected float, int, list, tuple, or numpy.ndarray."
)
return backend(x_arr)
@staticmethod
def _validate(x, y):
x = np.asarray(x, dtype=np.float64)
y = np.asarray(y, dtype=np.float64)
if x.ndim != 1:
raise ValueError(f"x must be 1-dimensional, got {x.ndim}D array")
if y.ndim != 1:
raise ValueError(f"y must be 1-dimensional, got {y.ndim}D array")
if len(x) < 2:
raise ValueError(f"x must have at least 2 points, got {len(x)}")
if len(x) != len(y):
raise ValueError(
f"Length of x and y must match. Got x: {len(x)}, y: {len(y)}"
)
return x, y
if __name__ == "__main__":
pass