# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.
# The LAMMPS Python bindings are an optional dependency of mdapy; fail early
# with an installation hint, keeping the original ImportError as the cause.
try:
    from lammps import lammps
except ImportError as err:
    raise ImportError(
        "One can install lammps python package: https://docs.lammps.org/Python_install.html"
    ) from err
import os
import sys
# On macOS, provide defaults for two OpenMP-related environment variables
# without overriding anything the user has already exported.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around several OpenMP
# runtimes being loaded into one process (e.g. by LAMMPS and another native
# extension) — confirm.
if sys.platform == "darwin":
    os.environ.update(
        {
            name: value
            for name, value in (
                ("OMP_NUM_THREADS", "1"),
                ("KMP_DUPLICATE_LIB_OK", "TRUE"),
            )
            if name not in os.environ
        }
    )
from mdapy.calculator import CalculatorMP
from mdapy.box import Box
import numpy as np
import polars as pl
from typing import List, Any, Optional
import contextlib
@contextlib.contextmanager
def silence(enabled: bool = True):
    """
    Redirect process-level stdout/stderr (file descriptors 1 and 2) to
    ``os.devnull`` for the duration of the ``with`` block; a no-op when
    ``enabled`` is False.

    Because the redirection happens at the file-descriptor level, it also
    silences output written directly by native (C/C++) code, not only output
    going through Python's ``sys.stdout``/``sys.stderr`` objects.

    Python-level ``sys.stdout``/``sys.stderr`` buffers are flushed right before
    the redirection and right before the restore. Text printed before the block
    is therefore not swallowed, and Python text printed inside the block goes
    to /dev/null instead of leaking out later. C-level stdio buffers are not
    flushed here.

    Parameters
    ----------
    enabled : bool, optional
        Whether to perform the redirection (default True).

    Yields
    ------
    None
    """
    if not enabled:
        yield
        return

    def _flush_python_streams() -> None:
        # Best-effort: a stream may be None (e.g. under pythonw), closed
        # (ValueError) or a broken pipe (OSError); none of that should abort.
        for stream in (sys.stdout, sys.stderr):
            if stream is None:
                continue
            with contextlib.suppress(OSError, ValueError):
                stream.flush()

    with contextlib.ExitStack() as stack:
        devnull = os.open(os.devnull, os.O_WRONLY)
        stack.callback(os.close, devnull)
        saved_stdout = os.dup(1)
        stack.callback(os.close, saved_stdout)
        saved_stderr = os.dup(2)
        stack.callback(os.close, saved_stderr)
        # ExitStack unwinds in LIFO order, so on exit: flush, restore fd 1,
        # restore fd 2, and only then close the saved/devnull descriptors.
        # Every descriptor acquired above is released even if a later step fails.
        stack.callback(os.dup2, saved_stderr, 2)
        stack.callback(os.dup2, saved_stdout, 1)
        stack.callback(_flush_python_streams)

        _flush_python_streams()
        os.dup2(devnull, 1)
        os.dup2(devnull, 2)
        yield
class LammpsPotential(CalculatorMP):
    """
    LAMMPS-based calculator that runs a single-point evaluation to obtain
    per-atom energies, forces, virials and the global stress.

    Parameters
    ----------
    pair_parameter : str
        The LAMMPS pair style / pair coeff commands as a single string.
        This string is passed directly to LAMMPS with ``commands_string``.
    element_list : List[str]
        List of element names supported by this potential. The index in this
        list defines the corresponding LAMMPS atom type (1-based). A copy of
        the list is stored.
    units : str, optional
        Units for LAMMPS (default ``"metal"``). Currently the code asserts
        that ``units == "metal"``.
    centroid_stress : bool, optional
        If True, uses ``compute centroid/stress/atom NULL`` in LAMMPS;
        otherwise uses ``compute stress/atom NULL``.
    cmdargs : list of str, optional
        Extra command-line arguments forwarded to ``lammps(cmdargs=...)``.
        Useful for accelerator packages, e.g. Kokkos:
        ``["-k", "on", "g", "1", "-sf", "kk", "-pk", "kokkos", "newton", "on", "neigh", "half"]``.
        These are appended after the default
        ``["-echo", "none", "-log", "none", "-screen", "none"]``.
    extra_commands : str, optional
        Extra LAMMPS commands executed after the box, atoms, masses and the
        per-atom computes are set up, and right before ``pair_parameter``.
        Use this for things that must come before the pair style (e.g.
        ``"package kokkos newton on neigh half"``, ``"newton on"``,
        ``"atom_modify ..."`` overrides handled in-script).
    silence_lammps : bool, optional
        If True (default), redirect LAMMPS stdout/stderr to /dev/null while
        the calculator runs. Set to False to see LAMMPS output for debugging.
    """

    # Conversion for LAMMPS ``metal`` units, where pressure is in bar:
    # 1 GPa = 1e4 bar and 1 eV/A^3 = 160.21766208 GPa. Values are divided by
    # both constants in sequence (same arithmetic as ``x / 1e4 / 160.21766208``).
    _BAR_PER_GPA = 1e4
    _GPA_PER_EV_PER_A3 = 160.21766208

    def __init__(
        self,
        pair_parameter: str,
        element_list: List[str],
        units: str = "metal",
        centroid_stress: bool = False,
        cmdargs: Optional[List[str]] = None,
        extra_commands: Optional[str] = None,
        silence_lammps: bool = True,
    ) -> None:
        assert units == "metal", "Only support metal units now."
        self.pair_parameter = pair_parameter
        # Copy so that later mutation of the caller's list cannot silently
        # change the element -> LAMMPS type mapping.
        self.element_list = list(element_list)
        self.units = units
        self.centroid_stress = centroid_stress
        self.cmdargs = list(cmdargs) if cmdargs else []
        self.extra_commands = extra_commands
        self.silence_lammps = silence_lammps
        # NOTE(review): ``self.results`` is used by the methods below but is
        # not created here and ``super().__init__()`` is not called;
        # presumably CalculatorMP provides it — confirm.

    def calculate(self, data: pl.DataFrame, box: Box) -> None:
        """
        Run LAMMPS to calculate per-atom energies, forces and virials and the
        global stress, storing them in ``self.results``.

        The method validates inputs, builds a (possibly general triclinic)
        LAMMPS box, creates the atoms, sets up the computes, executes
        ``run 0`` and extracts the LAMMPS quantities. It then converts units
        and reorders the virials.

        Parameters
        ----------
        data : polars.DataFrame
            Polars DataFrame with required columns: "x", "y", "z", "element".
        box : Box
            Box object from mdapy.

        Raises
        ------
        AssertionError
            If a required column is missing, an element is not in
            ``element_list``, or LAMMPS creates a different number of atoms.

        Notes
        -----
        - Atoms get LAMMPS IDs ``1..N`` in the row order of ``data``. All
          per-atom LAMMPS output is sorted by ID, so results follow that row
          order.
        - For general (non LAMMPS-aligned) boxes, positions are rotated into
          the LAMMPS frame. Forces are rotated back with ``rotate.T``.
          Per-atom virials and the global pressure are transformed with
          ``R @ T @ R.T``, where ``R = box.inverse_box @ box_lmp.box``.
        - Per-atom virials are the negated output of ``compute stress/atom``
          (or ``centroid/stress/atom``). They are converted to eV by dividing
          by ``1e4`` and ``160.21766208``, and stored as 9 components
          ``[xx, xy, xz, yx, yy, yz, zx, zy, zz]``.
        - The global stress is the negated LAMMPS thermo pressure tensor
          (``pxx`` ... ``pzz``, bar). It is converted to eV/A^3 the same way
          and stored in Voigt order ``[xx, yy, zz, yz, xz, xy]``. It is read
          from thermo rather than summed from per-atom virials because some
          potentials (e.g. mtp) cannot compute per-atom virials.
        - The LAMMPS instance is always closed, also when an error occurs.
        """
        for col in ["x", "y", "z", "element"]:
            assert col in data.columns, f"data does not have {col} information."
        for ele in data["element"].unique():
            assert ele in self.element_list, f"element_list does not have {ele} element."

        boundary = " ".join("p" if flag == 1 else "s" for flag in box.boundary)
        n_atom = data.shape[0]
        num_type = len(self.element_list)
        general = box.is_general_box()

        with silence(self.silence_lammps):
            base_cmdargs = ["-echo", "none", "-log", "none", "-screen", "none"]
            lmp = lammps(cmdargs=base_cmdargs + self.cmdargs)
            try:
                lmp.commands_string(f"units {self.units}")
                lmp.commands_string(f"boundary {boundary}")
                lmp.commands_string("atom_style atomic")
                cell = box.box
                # Define the cell from its three edge vectors (rows of
                # ``box.box``), then create an empty box with one type per element.
                create_box = (
                    "lattice custom 1.0"
                    f" a1 {cell[0, 0]} {cell[0, 1]} {cell[0, 2]}"
                    f" a2 {cell[1, 0]} {cell[1, 1]} {cell[1, 2]}"
                    f" a3 {cell[2, 0]} {cell[2, 1]} {cell[2, 2]}"
                    " basis 0.0 0.0 0.0 triclinic/general\n"
                    f"create_box {num_type} NULL 0 1 0 1 0 1"
                )
                lmp.commands_string(create_box)

                ele2type = {ele: idx + 1 for idx, ele in enumerate(self.element_list)}
                type_list = data.select(
                    pl.col("element")
                    .replace_strict(ele2type, return_dtype=pl.Int32)
                    .rechunk()
                    .alias("type")
                )["type"].to_numpy(allow_copy=False)
                id_list = np.arange(1, n_atom + 1)
                if general:
                    box_lmp, rotate = box.align_to_lammps_box()
                    # Relates the input cell to the LAMMPS-aligned cell; applied
                    # below as ``R @ T @ R.T`` to tensors read from LAMMPS.
                    R = box.inverse_box @ box_lmp.box
                    x_list = (
                        (data.select("x", "y", "z").to_numpy() - box.origin) @ rotate
                    ).flatten()
                else:
                    R = None
                    x_list = (
                        data.select(
                            pl.col("x") - box.origin[0],
                            pl.col("y") - box.origin[1],
                            pl.col("z") - box.origin[2],
                        )
                        .to_numpy()
                        .flatten()
                    )
                n_lmp = lmp.create_atoms(n_atom, id_list, type_list, x_list)
                assert n_atom == n_lmp, "Create atoms incorrectly."

                # Placeholder masses: LAMMPS requires a mass for every type
                # before ``run``; presumably irrelevant for a ``run 0`` evaluation.
                for type_id in range(1, num_type + 1):
                    lmp.commands_string(f"mass {type_id} 1.0")
                if self.centroid_stress:
                    lmp.commands_string("compute 1 all centroid/stress/atom NULL")
                else:
                    lmp.commands_string("compute 1 all stress/atom NULL")
                lmp.commands_string("compute 2 all pe/atom")
                if self.extra_commands:
                    lmp.commands_string(self.extra_commands)
                lmp.commands_string(self.pair_parameter)
                lmp.commands_string("run 0")

                # Per-atom arrays are not guaranteed to stay in creation order;
                # sorting by ID (1..N) restores the row order of ``data``.
                # Fancy indexing copies, so the arrays outlive ``lmp.close()``.
                sort_index = np.argsort(lmp.numpy.extract_atom("id")[:n_atom])
                energy = np.asarray(
                    lmp.numpy.extract_compute("2", 1, 1)[:n_atom][sort_index]
                )
                force = np.asarray(lmp.numpy.extract_atom("f")[:n_atom][sort_index])
                # stress/atom yields xx, yy, zz, xy, xz, yz; centroid/stress/atom
                # yields xx, yy, zz, xy, xz, yz, yx, zx, zy. Negated to get the
                # per-atom virial, then expanded to 9 components.
                virial = -np.asarray(
                    lmp.numpy.extract_compute("1", 1, 2)[:n_atom][sort_index]
                )
                virial = self._reorder_virial(virial)
                if general:
                    force = force @ rotate.T
                    virial = (R @ virial.reshape((n_atom, 3, 3)) @ R.T).reshape(
                        (n_atom, 9)
                    )
                # Global pressure is read from thermo because some potentials
                # (e.g. mtp) can not compute per-atom virials.
                pressure = self._read_pressure(lmp, R)
            finally:
                lmp.close()

        # bar -> eV/A^3 for the stress, bar*A^3 -> eV for the per-atom virials.
        self.results["stress"] = -pressure / self._BAR_PER_GPA / self._GPA_PER_EV_PER_A3
        self.results["energies"] = energy
        self.results["forces"] = force
        self.results["virials"] = virial / self._BAR_PER_GPA / self._GPA_PER_EV_PER_A3

    @staticmethod
    def _read_pressure(lmp: Any, R: Optional[np.ndarray]) -> np.ndarray:
        """
        Read the global pressure tensor from LAMMPS thermo keywords.

        Parameters
        ----------
        lmp : lammps
            LAMMPS instance after ``run 0``.
        R : numpy.ndarray or None
            3x3 transformation applied as ``R @ P @ R.T``; None to skip.

        Returns
        -------
        numpy.ndarray
            Pressure (LAMMPS units, bar for ``metal``) in Voigt order
            ``[xx, yy, zz, yz, xz, xy]``.
        """
        pxx, pyy, pzz, pyz, pxz, pxy = (
            lmp.get_thermo(key) for key in ("pxx", "pyy", "pzz", "pyz", "pxz", "pxy")
        )
        if R is None:
            return np.array([pxx, pyy, pzz, pyz, pxz, pxy])
        tensor = np.array(
            [
                [pxx, pxy, pxz],
                [pxy, pyy, pyz],
                [pxz, pyz, pzz],
            ]
        )
        tensor = R @ tensor @ R.T
        return np.array(
            [
                tensor[0, 0],
                tensor[1, 1],
                tensor[2, 2],
                tensor[1, 2],
                tensor[0, 2],
                tensor[0, 1],
            ]
        )

    def _reorder_virial(self, v: np.ndarray) -> np.ndarray:
        """
        Reorder virial array into a 9-component per-atom format.

        Parameters
        ----------
        v : numpy.ndarray
            Input virial array with shape (N, 9) or (N, 6) in LAMMPS ordering.
            If shape is (N, 9), columns are: xx, yy, zz, xy, xz, yz, yx, zx, zy.
            If shape is (N, 6), it is treated as symmetric: xx, yy, zz, xy, xz, yz.

        Returns
        -------
        numpy.ndarray
            Array of shape (N, 9) with columns ordered as:
            [xx, xy, xz, yx, yy, yz, zx, zy, zz]

        Raises
        ------
        ValueError
            If ``v`` has neither 6 nor 9 columns.
        """
        if v.shape[1] == 9:
            xx, yy, zz, xy, xz, yz, yx, zx, zy = v.T
        elif v.shape[1] == 6:
            # Symmetric tensor: mirror the off-diagonal components.
            xx, yy, zz, xy, xz, yz = v.T
            yx, zx, zy = xy, xz, yz
        else:
            raise ValueError("Input must have shape (N,9) or (N,6)")
        return np.column_stack([xx, xy, xz, yx, yy, yz, zx, zy, zz])

    def get_energies(self, data: pl.DataFrame, box: Box) -> Any:
        """
        Return per-atom energies (eV); triggers ``calculate()`` if not cached.

        Parameters
        ----------
        data : polars.DataFrame
        box : Box

        Returns
        -------
        Any
            Stored per-atom energies (``self.results["energies"]``).
        """
        if "energies" not in self.results:
            self.calculate(data, box)
        return self.results["energies"]

    def get_energy(self, data: pl.DataFrame, box: Box) -> Any:
        """
        Return the total energy, i.e. the sum of per-atom energies.

        Parameters
        ----------
        data : polars.DataFrame
        box : Box

        Returns
        -------
        Any
            Sum of per-atom energies.
        """
        return self.get_energies(data, box).sum()

    def get_forces(self, data: pl.DataFrame, box: Box) -> Any:
        """
        Return per-atom forces (N, 3) in the frame of ``data``; compute if necessary.
        """
        if "forces" not in self.results:
            self.calculate(data, box)
        return self.results["forces"]

    def get_stress(self, data: pl.DataFrame, box: Box) -> Any:
        """
        Return the global stress (eV/A^3) in Voigt order [xx, yy, zz, yz, xz, xy];
        compute if necessary.
        """
        if "stress" not in self.results:
            self.calculate(data, box)
        return self.results["stress"]

    def get_virials(self, data: pl.DataFrame, box: Box) -> Any:
        """
        Return per-atom virials (eV, 9 components ordered
        [xx, xy, xz, yx, yy, yz, zx, zy, zz]); compute if necessary.
        """
        if "virials" not in self.results:
            self.calculate(data, box)
        return self.results["virials"]
if __name__ == "__main__":
    # Manual sanity check: compare the energy of an fcc Ni crystal computed
    # by the Kokkos-accelerated LAMMPS EAM potential with mdapy's own EAM.
    # Usage: python <this file> [path/to/NiCoCr.lammps.eam]
    from mdapy import build_crystal, EAM

    # Fallback kept for backward compatibility with the original developer setup.
    default_eam_file = "/Users/herrwu/mypkg/mdapy/tests/input_files/NiCoCr.lammps.eam"
    eam_file = sys.argv[1] if len(sys.argv) > 1 else default_eam_file

    ni = build_crystal("Ni", "fcc", 3.53, nx=3, ny=3, nz=3)
    pot = LammpsPotential(
        pair_parameter=f"pair_style eam/alloy/kk\npair_coeff * * {eam_file} Ni",
        element_list=["Ni"],
        cmdargs=["-k", "on", "-sf", "kk", "-pk", "kokkos"],
    )
    ni.calc = pot
    e_kokkos = ni.get_energy()

    ni.calc = EAM(eam_file)
    e_eam = ni.get_energy()
    print(e_kokkos, e_eam)