# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.
try:
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
except ImportError:
raise ImportError("One can install ase by pip install ase.")
from mdapy import _nepcal
from mdapy.parallel import get_num_threads
from typing import Optional, List, Tuple
import numpy as np
import os
[docs]
class NEP4ASE(Calculator):
    """
    NEP calculator compatible with ASE (Atomic Simulation Environment).

    This class wraps the NEP calculator to work seamlessly with ASE's
    calculator interface, allowing NEP models to be used in ASE workflows
    for geometry optimization, molecular dynamics, and other simulations.

    Parameters
    ----------
    model_filename : str
        Path to the NEP model file
    atoms : Atoms, optional
        ASE Atoms object to attach the calculator to

    Attributes
    ----------
    calc : _nepcal.NEPCalculator
        Underlying NEP calculator object
    rc : float
        Larger of the model's radial and angular cutoffs; used to pad the
        box along non-periodic directions
    results : dict
        Dictionary storing calculation results. It always refers to the
        structure stored in ``self.atoms``; it is emptied whenever
        :meth:`calculate` receives a structure that differs from it.

    Examples
    --------
    >>> from ase import Atoms
    >>> from ase.optimize import BFGS
    >>> from mdapy.nep import NEP4ASE
    >>>
    >>> # Create atoms object
    >>> atoms = Atoms("Cu2", positions=[[0, 0, 0], [1.5, 0, 0]])
    >>> atoms.set_cell([10, 10, 10])
    >>> atoms.set_pbc(True)
    >>>
    >>> # Attach NEP calculator
    >>> calc = NEP4ASE("nep.txt")
    >>> atoms.calc = calc
    >>>
    >>> # Run geometry optimization
    >>> opt = BFGS(atoms)
    >>> opt.run(fmax=0.01)
    >>>
    >>> # Get energy and forces
    >>> energy = atoms.get_potential_energy()
    >>> forces = atoms.get_forces()
    """

    # Properties this calculator exposes through the standard ASE interface.
    implemented_properties = ["energy", "energies", "forces", "stress", "virials"]

    def __init__(
        self,
        model_filename: str,
        atoms: Optional[Atoms] = None,
    ):
        """
        Initialize NEP calculator for ASE.

        Parameters
        ----------
        model_filename : str
            Path to the NEP model file
        atoms : Atoms, optional
            ASE Atoms object to attach to this calculator

        Raises
        ------
        FileNotFoundError
            If the model file does not exist
        AssertionError
            If the loaded model reports a ``model_type`` other than 0
        """
        if not os.path.exists(model_filename):
            raise FileNotFoundError(f"{model_filename} does not exist.")

        # Load NEP model
        self.calc = _nepcal.NEPCalculator(model_filename)

        # Raised explicitly (instead of ``assert``) so the check survives
        # ``python -O`` while keeping the exception type callers may catch.
        if self.calc.info["model_type"] != 0:
            raise AssertionError("Only support energy NEP model.")

        # A positive charge mode marks a qNEP model (charges / bec available).
        self._is_qnep = bool(self.calc.info["charge_mode"] > 0)
        self.rc = max(
            self.calc.info["radial_cutoff"], self.calc.info["angular_cutoff"]
        )
        self.results = {}

        # Initialize ASE Calculator base class
        Calculator.__init__(self, atoms=atoms)

    def set_nep(
        self, atoms: Atoms
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepare ASE Atoms object for NEP calculation.

        Converts ASE Atoms format to the format required by NEP calculator.

        Parameters
        ----------
        atoms : Atoms
            ASE Atoms object

        Returns
        -------
        type_list : np.ndarray
            Integer (int32) type indices for each atom
        x : np.ndarray
            X-coordinates
        y : np.ndarray
            Y-coordinates
        z : np.ndarray
            Z-coordinates
        box : np.ndarray
            Cell matrix (3×3); the diagonal entry of every non-periodic
            direction is enlarged by ``3 * self.rc``

        Raises
        ------
        AssertionError
            If atoms contain elements not in the NEP model
        """
        symbols = atoms.get_chemical_symbols()

        # Build the symbol -> type index table once. ``setdefault`` keeps the
        # first occurrence, matching ``list.index`` semantics.
        element_index = {}
        for idx, element in enumerate(self.calc.info["element_list"]):
            element_index.setdefault(element, idx)

        # Validate that all elements are supported by the model. Sorted so the
        # first offending element reported is deterministic.
        for symbol in sorted(set(symbols)):
            if symbol not in element_index:
                raise AssertionError(f"NEP model did not include {symbol}.")

        # Map symbols to type indices
        type_list = np.array([element_index[s] for s in symbols], np.int32)

        # Get positions and a private copy of the cell
        pos = np.array(atoms.get_positions())
        new_box = np.array(atoms.get_cell())

        # Pad the box along non-periodic directions.
        for axis, periodic in enumerate(atoms.get_pbc()):
            if not periodic:
                new_box[axis, axis] += 3 * self.rc

        # Return data in NEP calculator format
        return type_list, pos[:, 0], pos[:, 1], pos[:, 2], new_box

    def calculate(
        self,
        atoms: Optional[Atoms] = None,
        properties: Optional[List[str]] = None,
        system_changes: List[str] = all_changes,
    ):
        """
        Perform calculation for requested properties.

        This method is called by ASE when properties are requested.
        It calculates the specified properties and stores them in
        self.results.

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object (uses self.atoms if None)
        properties : list of str, optional
            List of properties to calculate (default: all implemented)
        system_changes : list of str, optional
            List of changes since last calculation (for caching)

        Raises
        ------
        ValueError
            If no atoms are given and none are stored on the calculator

        Notes
        -----
        Special properties 'descriptor', 'latentspace', 'charges' and 'bec' are not in
        the standard ASE interface but are available through this calculator.

        'descriptor' and 'latentspace' are each computed when requested. The
        energy/forces/virials evaluation runs when ``properties`` is empty or
        contains anything other than those two names.

        'stress' is only stored when the cell has a positive volume; for a
        cell without a defined volume the remaining results are still stored.
        """
        # Set default properties if not specified
        if properties is None:
            properties = self.implemented_properties

        # The parent call below replaces self.atoms. Drop results that belong
        # to a different structure first, otherwise ASE's state check would
        # later see "no change" and hand out values of the old structure.
        if atoms is not None and self.check_state(atoms):
            self.results = {}

        # Call parent class calculate (stores a copy of atoms on self.atoms)
        Calculator.calculate(self, atoms, properties, system_changes)

        if atoms is None:
            atoms = self.atoms
        if atoms is None:
            raise ValueError(
                "No atoms to calculate: pass `atoms` or attach the calculator "
                "to an Atoms object first."
            )

        n_atoms = len(atoms)
        n_threads = get_num_threads()
        special = ("descriptor", "latentspace")

        if "descriptor" in properties:
            descriptor = np.zeros((n_atoms, self.calc.info["num_ndim"]), float)
            self.calc.get_descriptors(*self.set_nep(atoms), descriptor, n_threads)
            self.results["descriptor"] = descriptor

        if "latentspace" in properties:
            latentspace = np.zeros((n_atoms, self.calc.info["num_nlatent"]), float)
            self.calc.get_latentspace(*self.set_nep(atoms), latentspace, n_threads)
            self.results["latentspace"] = latentspace

        # An empty request keeps the historical behaviour of running the
        # standard evaluation.
        run_standard = not properties or any(p not in special for p in properties)
        if not run_standard:
            return

        # Standard calculation: energy, forces, stress, virials
        potential = np.zeros(n_atoms, float)
        force = np.zeros((n_atoms, 3), float)
        virial = np.zeros((n_atoms, 9), float)
        nep_inputs = self.set_nep(atoms)

        if self._is_qnep:
            charge = np.zeros(n_atoms, float)  # Per-atom charges
            bec = np.zeros((n_atoms, 9), float)  # Per-atom bec (9 components)
            self.calc.calculate_charge(
                *nep_inputs, potential, force, virial, charge, bec, n_threads
            )
            self.results["charges"] = charge
            self.results["bec"] = bec
        else:
            self.calc.calculate(*nep_inputs, potential, force, virial, n_threads)

        # Store results in ASE format
        self.results["energy"] = potential.sum()  # Total energy
        self.results["energies"] = potential  # Per-atom energies
        self.results["forces"] = force  # Forces
        self.results["virials"] = virial  # Per-atom virials

        # Stress needs a volume. ASE raises ValueError when the cell has fewer
        # than three lattice vectors; do not let that discard the results
        # stored above.
        try:
            volume = atoms.get_volume()
        except ValueError:
            volume = 0.0

        if volume > 0.0:
            # Symmetrised total virial divided by the volume
            v = virial.sum(axis=0).reshape(3, 3)
            stress = (-0.5 * (v + v.T) / volume).ravel()
            # Voigt notation: [σ_xx, σ_yy, σ_zz, σ_yz, σ_xz, σ_xy]
            self.results["stress"] = stress[[0, 4, 8, 5, 2, 1]]

    def get_descriptor(
        self,
        atoms: Optional[Atoms] = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """
        Get atomic descriptors (not part of standard ASE interface).

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object
        system_changes : list of str, optional
            System changes since last calculation

        Returns
        -------
        np.ndarray
            Descriptor array of shape (N, num_ndim)
        """
        self.calculate(atoms, ["descriptor"], system_changes)
        return self.results["descriptor"]

    def get_charges(
        self,
        atoms: Optional[Atoms] = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """
        Get atomic charges for qNEP model (not part of standard ASE interface).

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object
        system_changes : list of str, optional
            System changes since last calculation

        Returns
        -------
        np.ndarray
            Charge array of shape (N,)

        Raises
        ------
        ValueError
            If the loaded model is not a qNEP model
        """
        if not self._is_qnep:
            raise ValueError("Charges is only available for qNEP.")

        # Reuse the cached array only when it belongs to the requested
        # structure; a different `atoms` forces a new evaluation.
        stale = atoms is not None and bool(self.check_state(atoms))
        if stale or "charges" not in self.results:
            self.calculate(atoms, ["charges"], system_changes)
        return self.results["charges"]

    def get_bec(
        self,
        atoms: Optional[Atoms] = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """
        Get atomic bec for qNEP model (not part of standard ASE interface).

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object
        system_changes : list of str, optional
            System changes since last calculation

        Returns
        -------
        np.ndarray
            Bec array of shape (N, 9)

        Raises
        ------
        ValueError
            If the loaded model is not a qNEP model
        """
        if not self._is_qnep:
            raise ValueError("Bec is only available for qNEP.")

        # Reuse the cached array only when it belongs to the requested
        # structure; a different `atoms` forces a new evaluation.
        stale = atoms is not None and bool(self.check_state(atoms))
        if stale or "bec" not in self.results:
            self.calculate(atoms, ["bec"], system_changes)
        return self.results["bec"]

    def get_latentspace(
        self,
        atoms: Optional[Atoms] = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """
        Get latent space representations (not part of standard ASE interface).

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object
        system_changes : list of str, optional
            System changes since last calculation

        Returns
        -------
        np.ndarray
            Latent space array of shape (N, num_nlatent)
        """
        self.calculate(atoms, ["latentspace"], system_changes)
        return self.results["latentspace"]

    def get_virials(
        self,
        atoms: Optional[Atoms] = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """
        Get per-atom virials (not part of standard ASE interface).

        Parameters
        ----------
        atoms : Atoms, optional
            ASE Atoms object
        system_changes : list of str, optional
            System changes since last calculation

        Returns
        -------
        np.ndarray
            Virial array of shape (N, 9)
        """
        self.calculate(atoms, ["virials"], system_changes)
        return self.results["virials"]