# Copyright (c) 2022-2024, mushroomfire in Beijing Institute of Technology
# This file is from the mdapy project, released under the BSD 3-Clause License.
import taichi as ti
import numpy as np
try:
from .box import init_box, _pbc, _pbc_rec
from _neigh import (
build_cell_rec,
build_cell_rec_with_jishu,
build_cell_tri,
build_cell_tri_with_jishu,
)
except Exception:
from box import init_box, _pbc, _pbc_rec
from neigh._neigh import (
build_cell_rec,
build_cell_rec_with_jishu,
build_cell_tri,
build_cell_tri_with_jishu,
)
[docs]
@ti.data_oriented
class Neighbor:
"""This module is used to cerate neighbor of atoms within a given cutoff distance.
Using linked-cell list method makes fast neighbor finding possible.
Args:
pos (np.ndarray): (:math:`N_p, 3`) particles positions.
box (np.ndarray): (:math:`4, 3`) system box.
rc (float): cutoff distance.
boundary (list, optional): boundary conditions, 1 is periodic and 0 is free boundary. Defaults to [1, 1, 1].
max_neigh (int, optional): maximum neighbor number. If not given, will estimate atomatically. Default to None.
Outputs:
- **verlet_list** (np.ndarray) - (:math:`N_p, max\_neigh`) verlet_list[i, j] means j atom is a neighbor of i atom if j > -1.
- **distance_list** (np.ndarray) - (:math:`N_p, max\_neigh`) distance_list[i, j] means distance between i and j atom.
- **neighbor_number** (np.ndarray) - (:math:`N_p`) neighbor atoms number.
Examples:
>>> import mdapy as mp
>>> mp.init()
>>> import numpy as np
>>> FCC = mp.LatticeMaker(3.615, 'FCC', 10, 10, 10) # Create a FCC structure
>>> FCC.compute() # Get atom positions
>>> neigh = mp.Neighbor(FCC.pos, FCC.box,
5.) # Initialize Neighbor class.
>>> neigh.compute() # Calculate particle neighbor information.
>>> neigh.verlet_list # Check the neighbor index.
>>> neigh.distance_list # Check the neighbor distance.
>>> neigh.neighbor_number # Check the neighbor atom number.
"""
def __init__(self, pos, box, rc, boundary=[1, 1, 1], max_neigh=None):
if pos.dtype != np.float64:
self.pos = pos.astype(np.float64)
else:
self.pos = pos
self.box, self.inverse_box, self.rec = init_box(box)
self.rc = rc
self.boundary = ti.Vector([int(boundary[i]) for i in range(3)])
self.bin_length = self.rc + 0.01
self.max_neigh = max_neigh
self.box_length = ti.Vector([np.linalg.norm(self.box[i]) for i in range(3)])
self.N = self.pos.shape[0]
for i, direc in zip(range(3), ["x", "y", "z"]):
if self.boundary[i] == 1:
assert (
self.box_length[i] / 2 > self.bin_length
), f"box length along {direc} direction with periodic boundary should be two times larger than rc."
self.ncel = ti.Vector(
[
max(int(np.floor(np.linalg.norm(self.box[i]) / self.bin_length)), 3)
for i in range(3)
]
)
self.if_computed = False
@ti.kernel
def _build_verlet_list(
self,
pos: ti.types.ndarray(element_dim=1),
atom_cell_list: ti.types.ndarray(),
cell_id_list: ti.types.ndarray(),
verlet_list: ti.types.ndarray(),
distance_list: ti.types.ndarray(),
neighbor_number: ti.types.ndarray(),
box: ti.types.ndarray(element_dim=1),
inverse_box: ti.types.ndarray(element_dim=1),
):
rcsq = self.rc * self.rc
for i in range(pos.shape[0]):
nindex = 0
icel, jcel, kcel = 0, 0, 0
rij = ti.Vector([0.0, 0.0, 0.0], ti.f64)
if ti.static(self.rec):
icel, jcel, kcel = ti.floor(
(pos[i] - box[3]) / self.bin_length, dtype=int
)
else:
r = pos[i] - box[3]
n = (
r[0] * inverse_box[0]
+ r[1] * inverse_box[1]
+ r[2] * inverse_box[2]
)
icel = ti.floor((n[0] * box[0]).norm() / self.bin_length, int)
jcel = ti.floor((n[1] * box[1]).norm() / self.bin_length, int)
kcel = ti.floor((n[2] * box[2]).norm() / self.bin_length, int)
# iicel, jjcel, kkcel = icel, jcel, kcel # make sure correct cell
if icel < 0:
icel = 0
elif icel > self.ncel[0] - 1:
icel = self.ncel[0] - 1
if jcel < 0:
jcel = 0
elif jcel > self.ncel[1] - 1:
jcel = self.ncel[1] - 1
if kcel < 0:
kcel = 0
elif kcel > self.ncel[2] - 1:
kcel = self.ncel[2] - 1
for iicel in range(icel - 1, icel + 2):
for jjcel in range(jcel - 1, jcel + 2):
for kkcel in range(kcel - 1, kcel + 2):
j = cell_id_list[
iicel % self.ncel[0],
jjcel % self.ncel[1],
kkcel % self.ncel[2],
]
while j > -1:
if ti.static(self.rec):
rij = _pbc_rec(
pos[j] - pos[i], self.boundary, self.box_length
)
else:
rij = _pbc(
pos[j] - pos[i], self.boundary, box, inverse_box
)
rijdis_sq = rij[0] ** 2 + rij[1] ** 2 + rij[2] ** 2
# if i == 5:
# print(i, j, rijdis_sq)
if rijdis_sq <= rcsq and j != i:
verlet_list[i, nindex] = j
distance_list[i, nindex] = ti.sqrt(rijdis_sq)
nindex += 1
j = atom_cell_list[j]
neighbor_number[i] = nindex
[docs]
def compute(self):
"""Do the real neighbor calculation."""
atom_cell_list = np.zeros(self.N, dtype=np.int32)
cell_id_list = np.full(
(self.ncel[0], self.ncel[1], self.ncel[2]), -1, dtype=np.int32
)
need_check = True
if self.max_neigh is None:
max_neigh_list = np.zeros_like(cell_id_list)
if self.rec:
build_cell_rec_with_jishu(
self.pos,
atom_cell_list,
cell_id_list,
np.ascontiguousarray(self.box[-1]),
np.array([i for i in self.ncel]),
self.bin_length,
max_neigh_list,
)
else:
build_cell_tri_with_jishu(
self.pos,
atom_cell_list,
cell_id_list,
self.box,
self.inverse_box,
np.array([i for i in self.ncel]),
self.bin_length,
max_neigh_list,
)
self.max_neigh = (
max_neigh_list.max() * 4
) # np.partition(max_neigh_list.flatten(), -4)[-4:].sum()
need_check = False
# print(max_neigh_list, self.max_neigh)
else:
if self.rec:
build_cell_rec(
self.pos,
atom_cell_list,
cell_id_list,
np.ascontiguousarray(self.box[-1]),
np.array([i for i in self.ncel]),
self.bin_length,
)
else:
build_cell_tri(
self.pos,
atom_cell_list,
cell_id_list,
self.box,
self.inverse_box,
np.array([i for i in self.ncel]),
self.bin_length,
)
self.verlet_list = np.full((self.N, self.max_neigh), -1, dtype=np.int32)
self.distance_list = np.full(
(self.N, self.max_neigh), self.rc + 1.0, dtype=self.pos.dtype
)
self.neighbor_number = np.zeros(self.N, dtype=np.int32)
self._build_verlet_list(
self.pos,
atom_cell_list,
cell_id_list,
self.verlet_list,
self.distance_list,
self.neighbor_number,
self.box,
self.inverse_box,
)
if need_check:
max_neigh_number = self.neighbor_number.max()
assert (
max_neigh_number <= self.max_neigh
), f"Increase the max_neigh larger than {max_neigh_number}."
self._if_computed = True
@ti.kernel
def _partition_select_sort(
self, indices: ti.types.ndarray(), keys: ti.types.ndarray(), N: int
):
"""This function sorts N-th minimal value in keys.
Args:
indices (ti.types.ndarray): indices.
keys (ti.types.ndarray): values to be sorted.
N (int): number of sorted values.
"""
for i in range(indices.shape[0]):
for j in range(N):
minIndex = j
for k in range(j + 1, indices.shape[1]):
if keys[i, k] < keys[i, minIndex]:
minIndex = k
if minIndex != j:
keys[i, minIndex], keys[i, j] = keys[i, j], keys[i, minIndex]
indices[i, minIndex], indices[i, j] = (
indices[i, j],
indices[i, minIndex],
)
[docs]
def sort_verlet_by_distance(self, N: int):
"""This function sorts the first N-th verlet_list by distance_list.
Args:
N (int): number of sorted values
"""
if not self._if_computed:
self.compute()
assert (
self.neighbor_number.min() >= N
), f"N should lower than {self.neighbor_number.min()}."
self._partition_select_sort(self.verlet_list, self.distance_list, N)
if __name__ == "__main__":
from lattice_maker import LatticeMaker
from time import time
ti.init(offline_cache=True)
start = time()
lattice_constant = 4.05
x, y, z = 3, 3, 3
FCC = LatticeMaker(lattice_constant, "FCC", x, y, z)
FCC.compute()
end = time()
neigh = Neighbor(FCC.pos, FCC.box, 5.0)
neigh.compute()
print(neigh.verlet_list[0])
print(neigh.neighbor_number[0])
print(neigh.distance_list[0])
# for arch, tarch in zip(["cpu"], [ti.cpu]):
# # ti.init(
# # tarch, offline_cache=True, device_memory_fraction=0.9, default_fp=ti.f64
# # )
# start = time()
# neigh = Neighbor(FCC.pos, FCC.box, 5.0, max_neigh=50)
# neigh.compute()
# end = time()
# print(f"Arch: {arch}. Build neighbor time: {end-start} s.")
# print(neigh.verlet_list[0])
# print(neigh.neighbor_number[0])
# print(neigh.distance_list[0])
# print(np.linalg.norm(neigh.pos[5]-neigh.pos[0]))
# print(neigh.verlet_list.shape)
# print(neigh.distance_list.dtype)
# print(neigh.verlet_list.shape[1])
# print(neigh.neighbor_number.max())
# print(neigh.neighbor_number.min())
# neigh.sort_verlet_by_distance(12)
# print(neigh.verlet_list[0])
# print(neigh.distance_list[0])