# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.
from mdapy import _fast_knn
from mdapy.box import Box
from mdapy.parallel import get_num_threads
import mdapy.tool_function as tool
import polars as pl
import numpy as np
# The C++ KNN kernel uses a fixed-size on-stack heap; raising this beyond
# 24 requires recompiling the extension with a larger buffer.
# Inclusive upper bound on ``k`` accepted by ``NearestNeighbor``.
MAX_K = 24
class NearestNeighbor:
    """
    Perform a k-nearest-neighbor search for atoms in a periodic or non-periodic box.

    This class computes the indices and distances of the ``k`` nearest
    neighbors of every atom, considering periodic boundary conditions (PBC)
    where the box is periodic.

    The search is expected to exclude each atom from its own neighbor list
    (this matches OVITO's ``NearestNeighborFinder`` and freud with
    ``exclude_ii=True``, which the benchmark in this module compares
    against). A system therefore needs at least ``k + 1`` atoms to supply
    ``k`` neighbors. For smaller systems, the simulation box is
    automatically replicated along its periodic directions until enough
    atoms are available.

    Parameters
    ----------
    data : pl.DataFrame
        A Polars DataFrame containing atomic coordinates with columns
        ``"x"``, ``"y"``, and ``"z"``.
    box : Box
        The simulation box, defined as an instance of :class:`mdapy.box.Box`.
    k : int
        The number of nearest neighbors to search for.
        Must be in ``[1, MAX_K]``, i.e. at most 24.

    Attributes
    ----------
    indices_py : np.ndarray
        An ``int32`` array of shape ``(N, k)`` storing the indices of the
        nearest neighbors of each atom. If the box had to be replicated,
        ``N`` is the atom count of the replicated system, and the indices
        presumably refer to rows of ``_enlarge_data``.
    distances_py : np.ndarray
        A ``float64`` array of shape ``(N, k)`` storing the distances to
        the corresponding neighbors.
    _enlarge_data : pl.DataFrame, optional
        Replicated atomic data. Only set when replication was required.
    _enlarge_box : Box, optional
        The enlarged simulation box. Only set when replication was required.
    """

    def __init__(self, data: "pl.DataFrame", box: "Box", k: int):
        """Validate the inputs and store them. The search itself runs in :meth:`compute`.

        Raises
        ------
        AssertionError
            If a coordinate column is missing, ``data`` is empty, or ``k``
            is outside ``[1, MAX_K]``.
        """
        for col in ("x", "y", "z"):
            assert col in data.columns, f"data must contain column {col!r}."
        assert data.shape[0] > 0, "data must contain at least one atom."
        k = int(k)
        assert 1 <= k <= MAX_K, f"k must be in [1, {MAX_K}], got {k}."
        self.data = data
        self.box = box
        self.k = k

    def compute(self):
        """
        Compute the nearest neighbors for all atoms in the system.

        If the system has no more atoms than ``k`` (so it cannot supply ``k``
        neighbors per atom once each atom is excluded from its own list),
        it is first replicated along its periodic directions.

        Returns
        -------
        None
            Results are stored in the following attributes:

            - ``indices_py``: nearest-neighbor indices, ``int32``, shape ``(N, k)``.
            - ``distances_py``: corresponding distances, ``float64``, shape ``(N, k)``.

            ``N`` is the atom count of the (possibly replicated) system that
            is passed to the kernel.
        """
        data = self.data
        box = self.box
        repeat = self._check_repeat_nearest()
        if repeat != [1, 1, 1]:
            # Small system: replicate atoms so enough neighbors exist.
            self._enlarge_data, self._enlarge_box = tool.replicate(data, box, *repeat)
            box = self._enlarge_box
            data = self._enlarge_data
        N = data.shape[0]
        # Output buffers handed to the native kernel, which writes the results.
        self.indices_py = np.zeros((N, self.k), np.int32)
        self.distances_py = np.zeros((N, self.k), np.float64)
        _fast_knn.knn(
            # allow_copy=False: Polars raises rather than silently copying.
            data["x"].to_numpy(allow_copy=False),
            data["y"].to_numpy(allow_copy=False),
            data["z"].to_numpy(allow_copy=False),
            box.box,
            box.origin,
            box.boundary,
            self.k,
            self.indices_py,
            self.distances_py,
            get_num_threads(),
        )

    def _check_repeat_nearest(self):
        """
        Determine how many box replications are needed for the KNN search.

        Each atom is excluded from its own neighbor list, so ``N`` atoms can
        provide at most ``N - 1`` neighbors. When ``k >= N``, the box is
        replicated along its periodic directions (in steps of 3) until the
        replicated system holds at least ``k + 1`` atoms.

        NOTE(review): this only guarantees enough atoms. It does not check
        whether the box is thick enough for the k-th neighbor to lie within
        a single periodic image. Confirm how the kernel handles that.

        Returns
        -------
        repeat : list of int
            The replication count along the x, y, and z directions.

        Raises
        ------
        AssertionError
            If replication is needed but no direction is periodic.
        """
        repeat = [1, 1, 1]
        N = self.data.shape[0]
        if self.k >= N:
            # Use the same periodicity test for the check and the loop, so
            # the loop cannot spin forever when no axis can be grown.
            periodic = [i for i in range(3) if self.box.boundary[i] == 1]
            assert periodic, (
                f"Need periodic boundary if you want to query {self.k} neighbors "
                f"in {N}-atom system."
            )
            # "- 1" discounts the query atom itself.
            while np.prod(repeat) * N - 1 < self.k:
                for i in periodic:
                    repeat[i] += 3  # a safe number
        return repeat
if __name__ == "__main__":
    # Benchmark / cross-check script: times mdapy's KNN against OVITO (with a
    # distance comparison) and freud. Requires ovito, freud and a local
    # "test1.xyz" file.
    from time import perf_counter

    from ovito.data import NearestNeighborFinder
    from ovito.io import import_file

    from mdapy import System

    filename = "test1.xyz"
    k = 12
    atom = import_file(filename).compute()
    system = System(ovito_atom=atom)
    print("atom number: ", system.N)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = perf_counter()
    finder = NearestNeighborFinder(k, atom)
    _, vec = finder.find_all()
    print("ovito time:", perf_counter() - start)

    start = perf_counter()
    system.build_nearest_neighbor(k)
    print("mdapy time:", perf_counter() - start)

    # OVITO returns neighbor vectors; their norms must match mdapy's distances.
    res = np.linalg.norm(vec, axis=-1)
    assert np.allclose(res, system.distance_list)

    # freud is imported here, after the OVITO/mdapy comparison, so that
    # comparison completes even when freud is not installed.
    import freud

    aq = freud.locality.AABBQuery.from_system(atom)
    start = perf_counter()
    query_result = aq.query(
        aq.points, dict(mode="nearest", num_neighbors=k, exclude_ii=True)
    )
    nlist = query_result.toNeighborList()
    print("freud time:", perf_counter() - start)