Source code for mdapy.trajectory

# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.

from typing import List, Optional, Dict, Any, Union, Iterator, Tuple, TextIO
import numpy as np
import polars as pl
import re
from mdapy.system import System
from mdapy.box import Box

# Standard column names that have special meaning in XYZ format
_STANDARD_COLS = (
    "element",
    "x",
    "y",
    "z",
    "vx",
    "vy",
    "vz",
    "fx",
    "fy",
    "fz",
    "bec_0",
    "bec_1",
    "bec_2",
    "bec_3",
    "bec_4",
    "bec_5",
    "bec_6",
    "bec_7",
    "bec_8",
)


class _TrajectoryListBase:
    """Common list-like API shared by :class:`XYZTrajectory` and
    :class:`Trajectory`. Subclasses are expected to own a
    ``self._systems: list[System]`` attribute and a ``save`` method.

    Splitting the list-API into a mixin keeps the per-format readers
    (which still live on :class:`XYZTrajectory`) free of bookkeeping
    code, and eliminates the duplicated list-method implementations
    that used to live on both classes.
    """

    _systems: List[System]

    def __len__(self) -> int:
        return len(self._systems)

    def __getitem__(self, idx):
        """Index by ``int`` (single frame), ``slice``, ``list``/``tuple`` or
        a 1-D ``numpy.ndarray`` (integer or boolean).

        - ``traj[3]`` → :class:`System` (the frame at index 3)
        - ``traj[1:4]`` → same-type container holding 3 frames
        - ``traj[[0, 5, 7]]`` → same-type container holding 3 frames
        - ``traj[np.array([True, False, True, ...])]`` → frames at the
          ``True`` positions (the mask must have length ``len(traj)``)

        Boolean masks support filtering on derived per-frame quantities,
        e.g. ``traj[traj.get_atoms_count() > 100]``.
        """
        if isinstance(idx, slice):
            # Wrap the slice in the same concrete subclass so users
            # who do `traj[:5]` get back the same class they started
            # with (XYZTrajectory or Trajectory).
            return type(self)(systems=self._systems[idx])

        # Fancy indexing: bool mask or integer index array / list / tuple.
        if isinstance(idx, (list, tuple, np.ndarray)):
            arr = np.asarray(idx)
            if arr.dtype == bool:
                if arr.shape != (len(self._systems),):
                    raise IndexError(
                        f"boolean mask must have length {len(self._systems)} "
                        f"to index a {len(self._systems)}-frame trajectory; "
                        f"got length {arr.shape[0] if arr.ndim else 'scalar'}."
                    )
                picked = [self._systems[i] for i in np.flatnonzero(arr)]
            elif np.issubdtype(arr.dtype, np.integer):
                # Allow negative indices the same way numpy does.
                n = len(self._systems)
                norm = [int(i) + n if int(i) < 0 else int(i) for i in arr]
                for i in norm:
                    if i < 0 or i >= n:
                        raise IndexError(
                            f"frame index {i} out of bounds for "
                            f"{n}-frame trajectory."
                        )
                picked = [self._systems[i] for i in norm]
            else:
                raise TypeError(
                    f"trajectory index array must be bool or integer; "
                    f"got dtype {arr.dtype}."
                )
            return type(self)(systems=picked)

        # Plain int — return the underlying System.
        return self._systems[idx]

    def __setitem__(self, idx: int, system: System) -> None:
        if not isinstance(system, System):
            raise TypeError("can only assign System instances")
        self._systems[idx] = system

    def __iter__(self) -> Iterator[System]:
        return iter(self._systems)

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: {len(self)} frame(s)>"

    def append(self, system: System) -> None:
        if not isinstance(system, System):
            raise TypeError("only System instances can be appended")
        self._systems.append(system)

    def extend(self, systems: List[System]) -> None:
        for s in systems:
            self.append(s)

    def insert(self, index: int, system: System) -> None:
        if not isinstance(system, System):
            raise TypeError("only System instances can be inserted")
        self._systems.insert(index, system)

    def pop(self, index: int = -1) -> System:
        return self._systems.pop(index)

    def remove(self, indices: Union[int, List[int]]) -> None:
        if isinstance(indices, int):
            indices = [indices]
        for i in sorted(indices, reverse=True):
            self._systems.pop(i)

    def get_atoms_count(self) -> np.ndarray:
        """Per-frame atom counts as a 1-D ``int64`` numpy array.

        Returning a numpy array (rather than a Python list) lets users
        write boolean filters directly:

        >>> hot_frames = traj[traj.get_atoms_count() > 100]
        """
        return np.array([s.N for s in self._systems], dtype=np.int64)

    def concatenate(self, other: "_TrajectoryListBase") -> "_TrajectoryListBase":
        """Return a new container holding ``self`` followed by ``other``."""
        return type(self)(systems=self._systems + other._systems)


[docs] class XYZTrajectory(_TrajectoryListBase): """ XYZ trajectory file reader and manager. This class provides functionality to read, manipulate, and write XYZ format trajectory files. It supports both classical XYZ format (element + coordinates) and extended XYZ format (with periodic boundary conditions and additional properties). Two reading modes are available: - Serial mode: Read frames sequentially (default) - Fast mode: Optimized batch reading assuming all frames have identical columns Parameters ---------- filename : Optional[str] Path to XYZ trajectory file to load systems : Optional[List[System]] List of System objects to initialize trajectory from memory fast_mode : bool, default=False If True, use optimized reading assuming all frames have the same columns. This mode is significantly faster but requires all frames to have identical column structure. Raises ------ ValueError If neither filename nor systems is provided Examples -------- >>> # Load trajectory in serial mode >>> traj = XYZTrajectory("trajectory.xyz") >>> # Load trajectory in fast mode >>> traj = XYZTrajectory("trajectory.xyz", fast_mode=True) >>> # Create trajectory from existing System objects >>> traj = XYZTrajectory(systems=[system1, system2]) >>> # Access frames >>> print(len(traj)) # Number of frames >>> frame = traj[0] # Get first frame >>> sub_traj = traj[0:10] # Slice trajectory >>> # Iterate over frames >>> for frame in traj: ... print(frame.N) >>> # Save trajectory >>> traj.save("output.xyz") >>> traj.save("output.xyz", frames=[0, 1, 2]) # Save specific frames """ def __init__( self, filename: Optional[str] = None, systems: Optional[List[System]] = None, fast_mode: bool = False, verbose: bool = True, ) -> None: self._systems: List[System] = [] self._filename = filename self._fast_mode = fast_mode self._verbose = verbose if systems is not None: self._systems = systems elif filename is not None: self._load() else: raise ValueError("At least has systems or filename.") def _load(self) -> None: """Load trajectory file using the specified reading mode.""" if self._fast_mode: self._systems = self._read_xyz_fast(self._filename) else: self._systems = self._read_xyz_serial(self._filename, verbose=self._verbose) def _read_xyz_fast(self, filename: str) -> List[System]: """ Fast trajectory reading mode. Assumes all frames have identical column structure, which allows for optimized batch processing using vectorized operations. Parameters ---------- filename : str Path to XYZ file Returns ------- List[System] List of System objects for each frame """ df_raw = ( pl.read_csv( filename, has_header=False, new_columns=["line"], separator="\n", truncate_ragged_lines=True, ) .with_columns(pl.col("line").str.strip_chars()) .with_row_index() ) # Locate frame boundaries row = 0 row_list = [] while row < df_raw.shape[0]: row_list.append(row) N = int(df_raw.item(row, 1)) row += N + 2 row_list = np.array(row_list, np.int32) sele = np.r_[row_list, row_list + 1] Nframe = row_list.shape[0] # Detect format type from first frame first_comment_line = df_raw.item(row_list[0] + 1, 1) is_classical = "lattice" not in first_comment_line.lower() origin = np.zeros(3) if is_classical: # Classical XYZ format columns = ["element", "x", "y", "z"] schema = { "element": pl.Utf8, "x": pl.Float64, "y": pl.Float64, "z": pl.Float64, } boundary = np.array([0, 0, 0], np.int32) else: # Extended XYZ format - parse metadata from comment lines target_keys = [ "lattice", "properties", "pbc", "energy", "origin", "force", "virial", "stress", ] exprs = [] for key in target_keys: pat = rf'(?i){key}=(?:"([^"]+)"|([^ \n]+))' expr = ( pl.col("line") .str.extract_groups(pat) .struct.field("1") .fill_null(pl.col("line").str.extract_groups(pat).struct.field("2")) .alias(key) ) exprs.append(expr) comment_info = ( df_raw.filter(pl.col("index").is_in(row_list + 1)) .with_columns(pl.col("line").str.replace("'", '"')) .select(exprs) ) columns, schema = self._parse_properties(comment_info["properties"].item(0)) has_pbc = not comment_info["pbc"].has_nulls() has_origin = not comment_info["origin"].has_nulls() boundary = np.ones(3, np.int32) # Parse atom data for all frames rep = ( df_raw.filter(pl.col("index").is_in(row_list))["line"] .cast(pl.Int32) .to_numpy() ) frame = np.repeat(np.arange(Nframe), rep) # Get the first data line to determine the actual separator. We # match the whitespace run immediately after the first # non-whitespace token rather than using `str.find` on the # second token — `find` returns the FIRST occurrence and if the # second token happens to equal the first (e.g. `0.5 0.5 ...`) # it would mis-identify the separator as the empty string. first_data_line: str = df_raw.item(row_list[0] + 2, 1) sep_match = re.match(r"\S+(\s+)", first_data_line) separator = sep_match.group(1) if sep_match else " " all_data = ( df_raw.filter(~pl.col("index").is_in(sele)) .select( pl.col("line") .str.split_exact(separator, n=len(columns)) .struct.rename_fields(columns) .alias("_tmp") ) .unnest("_tmp") .cast(schema) .with_columns(frame=frame) .partition_by("frame", maintain_order=True, include_key=False) ) # Build System objects for each frame systems = [] for i in range(Nframe): if is_classical: coor = all_data[i].select("x", "y", "z") box = np.eye(3) * (coor.max() - coor.min()).to_numpy() new_box = Box(box=box, origin=origin, boundary=boundary) info = {} else: box = np.array(comment_info["lattice"].item(i).split(), float).reshape( 3, 3 ) if has_pbc: boundary = np.array( [ 1 if j in ("T", "1") else 0 for j in comment_info["pbc"].item(i).split() ], np.int32, ) if has_origin: origin = np.array(comment_info["origin"].item(i).split(), float) new_box = Box(box=box, origin=origin, boundary=boundary) info = {} for kk in ["energy", "force", "virial", "stress"]: va = comment_info[kk].item(i) if va is not None: info[kk] = va systems.append(System(box=new_box, data=all_data[i], global_info=info)) return systems def _read_xyz_serial(self, filename: str, verbose: bool = False) -> List[System]: """ Serial trajectory reading mode. Slower than the fast path but tolerant of multi-space separators and per-frame schema drift. Parameters ---------- filename : str Path to XYZ file. verbose : bool, default=False Print progress every 200 frames during the read. """ systems = [] frame_idx = 0 with open(filename, "r") as f: while True: natom_line = f.readline() if not natom_line or not natom_line.strip(): break natom = int(natom_line.strip()) info_line = f.readline() if not info_line: break data_lines = [] for _ in range(natom): line = f.readline() if not line: break data_lines.append(line) if len(data_lines) != natom: break df, box, global_info = self._parse_frame(info_line, data_lines) systems.append(System(data=df, box=box, global_info=global_info)) frame_idx += 1 if verbose and frame_idx % 200 == 0: # XYZ is a streaming format — we don't know the total # frame count without a full pre-scan, so print just # the running counter and refresh the same line. print( f"\r [xyz.serial] frame {frame_idx} read ", end="", flush=True, ) if verbose: print(f"\r [xyz.serial] done — {frame_idx} frames ", flush=True) return systems
[docs] def save( self, filename: str, frames: Optional[Union[List[int], int]] = None, mode: str = "w", vacuum: float = 0.0, ) -> None: """ Save trajectory to XYZ file. Parameters ---------- filename : str Output file path frames : Optional[Union[List[int], int]], default=None Frame indices to save. Can be: - None: save all frames - int: save single frame - List[int]: save specified frames mode : str, default='w' Writing mode can be: - 'w' : write mode - 'a' : append mode vacuum : float, default=0.0 When ``> 0``, every non-periodic axis of every saved frame is padded by ``vacuum`` Å (atoms shifted by ``vacuum / 2`` so they sit centred in the new cell, boundary flipped to periodic). Useful for auto-boxing classical-XYZ frames so downstream MD code sees a well-defined supercell. The in-memory trajectory is not mutated — padding is applied to a per-frame copy at write time. Examples -------- >>> traj.save("output.xyz") # save all frames >>> traj.save("output.xyz", 0) # save first frame only >>> traj.save("output.xyz", [0, 5, 10]) # save specific frames >>> traj.save("output.xyz", vacuum=200) # auto-box FFF frames """ if vacuum < 0: raise ValueError(f"vacuum must be >= 0, got {vacuum}.") if frames is None: systems_to_save = self._systems elif isinstance(frames, int): systems_to_save = [self._systems[frames]] else: systems_to_save = [self._systems[i] for i in frames] assert mode in ["w", "a"] with open(filename, mode) as f: for system in systems_to_save: if vacuum > 0: system = _pad_with_vacuum(system, vacuum) self._write_single_frame(f, system)
# NOTE: list-like API (__len__, __getitem__, __setitem__, __iter__, # __repr__, append, extend, insert, pop, remove, get_atoms_count, # concatenate) is inherited from `_TrajectoryListBase` so it stays # in sync between XYZTrajectory and Trajectory. def _parse_frame( self, info_line: str, data_lines: List[str] ) -> Tuple[pl.DataFrame, Box, Optional[Dict[str, Any]]]: """ Parse a single XYZ frame. Parses the comment line and data lines to extract atomic data, box information, and global metadata. Parameters ---------- info_line : str Comment line containing metadata data_lines : List[str] Lines containing atomic data Returns ------- Tuple[pl.DataFrame, Box, Optional[Dict[str, Any]]] DataFrame with atomic data, Box object, and global_info dictionary Raises ------ ValueError If extended XYZ format is missing required 'properties' field """ global_info = {} results = re.findall( r'(\w+)=(?:"([^"]+)"|([^ ]+))', info_line.replace("'", '"') ) for match in results: key = match[0].lower() value = match[1] if match[1] else match[2] global_info[key] = value classical = "lattice" not in global_info if not classical: if "properties" not in global_info: raise ValueError("Extended XYZ must contain 'properties'") boundary = [1, 1, 1] if "pbc" in global_info: boundary = [ 1 if i in ("T", "1") else 0 for i in global_info["pbc"].split() ] box_array = np.array(global_info["lattice"].split(), float).reshape(3, 3) origin = np.zeros(3, float) if "origin" in global_info: origin = np.array(global_info["origin"].split(), float) columns, schema = self._parse_properties(global_info["properties"]) else: boundary = [0, 0, 0] columns = ["element", "x", "y", "z"] schema = { "element": pl.Utf8, "x": pl.Float64, "y": pl.Float64, "z": pl.Float64, } origin = np.zeros(3, float) # Per-frame parsing strategy (chosen empirically; see the # benchmarks in tests/_generate_fixtures/README or the release # notes for the numbers): # * uniform single-space block → `pl.read_csv` (fastest; # ~0.4 ms / 1k atoms when the text really is single-space). # * any irregular whitespace → numpy `str.split` + per-column # `pl.Series(.astype(...))`. Hot-cache head-to-head on a # 2616×514 multi-space training set gave numpy 2.2 s vs # pure-Python dict 2.6 s; on a 7343×~150 13-column file # it was numpy 4.3 s vs dict 5.7 s. `np.loadtxt` was the # slowest (3+ ms / 1k atoms) — Python-side parser overhead. # Bulk-vectorised reading remains available via `fast_mode=True` # which amortises ALL per-frame work (regex on the comment # line, `_parse_properties`, Box construction) into one pass. from mdapy.load_save import _is_uniform_single_space import io as _io if _is_uniform_single_space(data_lines, len(columns)): buf = _io.StringIO("".join(data_lines)) df = pl.read_csv(buf, separator=" ", schema=schema, has_header=False) else: cells = np.array([row.split()[: len(columns)] for row in data_lines]) df_cols = {} for j, c in enumerate(columns): col = cells[:, j] if schema[c] == pl.Int32: df_cols[c] = pl.Series(c, col.astype(np.int32), dtype=pl.Int32) elif schema[c] == pl.Utf8: df_cols[c] = pl.Series(c, col.tolist(), dtype=pl.Utf8) else: df_cols[c] = pl.Series(c, col.astype(np.float64), dtype=pl.Float64) df = pl.DataFrame(df_cols) if classical: coor = df.select("x", "y", "z") extents = (coor.max() - coor.min()).to_numpy().flatten() # Pad zero extents (e.g. a single atom or a planar config) # so Box() can still invert the cell matrix. extents = np.where(extents > 0, extents, 1e-9) box_array = np.diag(extents) for key in ["pbc", "properties", "origin", "lattice"]: global_info.pop(key, None) return df.rechunk(), Box(box_array, boundary, origin), global_info def _parse_properties( self, properties_str: str ) -> Tuple[List[str], Dict[str, Any]]: """ Parse extended XYZ properties string. Properties format: "name:type:count:name:type:count:..." where type is S (string), R (real/float), or I (integer). Parameters ---------- properties_str : str Properties string (e.g., "species:S:1:pos:R:3:force:R:3") Returns ------- Tuple[List[str], Dict[str, Any]] Column names and Polars schema dictionary Raises ------ ValueError If property type is not recognized (not S, R, or I) Notes ----- Special property names are mapped to standard column names: - "pos" -> ["x", "y", "z"] - "species"/"element" -> "element" - "vel"/"velo" -> ["vx", "vy", "vz"] - "force"/"forces" -> ["fx", "fy", "fz"] """ content = properties_str.strip().split(":") i = 0 columns = [] schema = {} while i < len(content) - 2: n_col = int(content[i + 2]) if content[i + 1] == "S": dtype = pl.Utf8 elif content[i + 1] == "R": dtype = pl.Float64 elif content[i + 1] == "I": dtype = pl.Int32 else: raise ValueError(f"Unrecognized type {content[i + 1]}") # Magic-name mapping is applied only for the *first* # occurrence; if a second property in the same Properties # string would alias to the same canonical names (e.g. both # `force:R:3` and `forces:R:3` → fx/fy/fz), the alias is # disabled for the second one and it falls through to the # generic `<name>_<j>` path. This keeps every column unique # so the data-row split lines up with the column list. if ( content[i] == "pos" and content[i + 1] == "R" and n_col == 3 and "x" not in schema ): columns.extend(["x", "y", "z"]) for coord in ["x", "y", "z"]: schema[coord] = dtype elif ( # GPUMD writes "unwrapped_position:R:3" for trajectories # already unwrapped at the simulator side. Map it to the # LAMMPS-style ``xu/yu/zu`` triplet so downstream code # (``unwrap_trajectory``, MSD, etc.) sees a uniform column # name regardless of source. content[i] in ("unwrapped_position", "unwrapped_pos") and content[i + 1] == "R" and n_col == 3 and "xu" not in schema ): columns.extend(["xu", "yu", "zu"]) for coord in ["xu", "yu", "zu"]: schema[coord] = dtype elif ( content[i] in ["species", "element"] and content[i + 1] == "S" and n_col == 1 and "element" not in schema ): columns.append("element") schema["element"] = dtype elif ( content[i] in ["velo", "vel"] and content[i + 1] == "R" and n_col == 3 and "vx" not in schema ): columns.extend(["vx", "vy", "vz"]) for vel in ["vx", "vy", "vz"]: schema[vel] = dtype elif ( content[i] in ["force", "forces"] and content[i + 1] == "R" and n_col == 3 and "fx" not in schema ): columns.extend(["fx", "fy", "fz"]) for force in ["fx", "fy", "fz"]: schema[force] = dtype else: if n_col > 1: for j in range(n_col): col_name = f"{content[i]}_{j}" # Defensive: if the generic name *also* # collides (extremely rare but possible when # two properties share a base name), tack on # extra suffixes until unique. suffix = 0 unique_name = col_name while unique_name in schema: suffix += 1 unique_name = f"{col_name}__{suffix}" columns.append(unique_name) schema[unique_name] = dtype else: base = content[i] suffix = 0 unique_name = base while unique_name in schema: suffix += 1 unique_name = f"{base}__{suffix}" columns.append(unique_name) schema[unique_name] = dtype i += 3 return columns, schema def _write_single_frame(self, f: TextIO, system: System) -> None: """ Write a single frame to XYZ file. Writes in either classical XYZ format (no periodic boundaries) or extended XYZ format (with lattice, pbc, and properties). Parameters ---------- f : TextIO File handle opened for writing system : System System object containing frame data to write Notes ----- Classical format is used when system.box.boundary.sum() == 0, otherwise extended format is used. """ df = system.data natom = len(df) f.write(f"{natom}\n") info_parts = [] # Determine format based on boundary conditions is_extended = system.box.boundary.sum() > 0 if is_extended: # Extended XYZ format lattice = system.box.box.flatten() lattice_str = " ".join(f"{x:.10f}" for x in lattice) info_parts.append(f'Lattice="{lattice_str}"') pbc = " ".join("T" if b else "F" for b in system.box.boundary) info_parts.append(f'pbc="{pbc}"') if hasattr(system.box, "origin") and np.any(system.box.origin != 0): origin_str = " ".join(f"{x:.10f}" for x in system.box.origin) info_parts.append(f'Origin="{origin_str}"') properties = [] if "element" in df.columns: properties.append("species:S:1") properties.append("pos:R:3") if all(c in df.columns for c in ["vx", "vy", "vz"]): properties.append("vel:R:3") if all(c in df.columns for c in ["fx", "fy", "fz"]): properties.append("force:R:3") if all( c in df.columns for c in [ "bec_0", "bec_1", "bec_2", "bec_3", "bec_4", "bec_5", "bec_6", "bec_7", "bec_8", ] ): properties.append("bec:R:9") # Add other columns to properties for col in df.columns: if col not in _STANDARD_COLS: dtype = df.schema[col] if dtype == pl.Utf8: properties.append(f"{col}:S:1") elif dtype in [pl.Float32, pl.Float64]: properties.append(f"{col}:R:1") elif dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]: properties.append(f"{col}:I:1") info_parts.append(f"Properties={':'.join(properties)}") # Add global_info metadata if system.global_info: for key, value in system.global_info.items(): try: value_str = str(value) if not value_str.startswith("<") and not value_str.startswith("["): if "energy" in key: info_parts.append(f"{key}={value_str}") else: info_parts.append(f'{key}="{value_str}"') except Exception: continue # Write comment line if info_parts: f.write(" ".join(info_parts) + "\n") else: f.write("\n") # Determine column write order write_columns = [] if "element" in df.columns: write_columns.append("element") write_columns.extend(["x", "y", "z"]) if is_extended: # Extended format: write additional columns if all(c in df.columns for c in ["vx", "vy", "vz"]): write_columns.extend(["vx", "vy", "vz"]) if all(c in df.columns for c in ["fx", "fy", "fz"]): write_columns.extend(["fx", "fy", "fz"]) if all( c in df.columns for c in [ "bec_0", "bec_1", "bec_2", "bec_3", "bec_4", "bec_5", "bec_6", "bec_7", "bec_8", ] ): write_columns.extend( [ "bec_0", "bec_1", "bec_2", "bec_3", "bec_4", "bec_5", "bec_6", "bec_7", "bec_8", ] ) # Add remaining custom columns for col in df.columns: if col not in _STANDARD_COLS: write_columns.append(col) # Write data rows for row in df.select(write_columns).iter_rows(): line_data = [] for val in row: if isinstance(val, float): line_data.append(f"{val:.10f}") else: line_data.append(str(val)) f.write(" ".join(line_data) + "\n")
# =========================================================================== # Module-level XYZ writers (used by both XYZTrajectory and Trajectory) # =========================================================================== def _pad_with_vacuum(system: System, vacuum: float) -> System: """Return a new :class:`System` with a vacuum buffer added along every non-periodic axis. For each axis ``i`` whose boundary is open (``box.boundary[i] == 0``) the cell is extended by ``vacuum`` Å, the atoms are shifted by ``vacuum / 2`` along ``i`` so the original cluster sits centred in the padded box, and the boundary on those axes flips to periodic. Periodic axes are left untouched. ``vacuum == 0`` is a no-op and returns the input system unchanged. Used by the trajectory writers to auto-box training-set frames that came in as classical XYZ (PBC = FFF) so downstream MD code sees a well-defined supercell. """ if vacuum < 0: raise ValueError(f"vacuum must be >= 0, got {vacuum}.") if vacuum == 0: return system boundary = list(system.box.boundary) if all(b == 1 for b in boundary): return system # nothing to pad new_box_mat = np.asarray(system.box.box, dtype=float).copy() new_origin = np.asarray(system.box.origin, dtype=float).copy() new_boundary = list(boundary) shift = np.zeros(3, dtype=float) for i in range(3): if boundary[i] == 0: new_box_mat[i, i] += vacuum shift[i] = vacuum / 2.0 new_boundary[i] = 1 new_data = system.data.with_columns( pl.col("x") + shift[0], pl.col("y") + shift[1], pl.col("z") + shift[2], ) return System( data=new_data, box=Box(new_box_mat, new_boundary, new_origin), global_info=system.global_info, ) def _write_xyz_frame_to(f: TextIO, system: System, vacuum: float = 0.0) -> None: """Write one XYZ frame to an open text file. Same logic as :meth:`XYZTrajectory._write_single_frame` but as a free function so :class:`Trajectory` can call it without instantiating XYZTrajectory. When ``vacuum > 0`` and the system has any non-periodic axis, the written frame is the result of :func:`_pad_with_vacuum` — a padded copy of the input. The original ``system`` object is not mutated. """ if vacuum > 0: system = _pad_with_vacuum(system, vacuum) df = system.data natom = len(df) f.write(f"{natom}\n") info_parts = [] is_extended = system.box.boundary.sum() > 0 if is_extended: lattice = system.box.box.flatten() lattice_str = " ".join(f"{x:.10f}" for x in lattice) info_parts.append(f'Lattice="{lattice_str}"') pbc = " ".join("T" if b else "F" for b in system.box.boundary) info_parts.append(f'pbc="{pbc}"') if hasattr(system.box, "origin") and np.any(system.box.origin != 0): origin_str = " ".join(f"{x:.10f}" for x in system.box.origin) info_parts.append(f'Origin="{origin_str}"') properties = [] if "element" in df.columns: properties.append("species:S:1") properties.append("pos:R:3") if all(c in df.columns for c in ["vx", "vy", "vz"]): properties.append("vel:R:3") if all(c in df.columns for c in ["fx", "fy", "fz"]): properties.append("force:R:3") bec_cols = [f"bec_{i}" for i in range(9)] if all(c in df.columns for c in bec_cols): properties.append("bec:R:9") for col in df.columns: if col not in _STANDARD_COLS: dtype = df.schema[col] if dtype == pl.Utf8: properties.append(f"{col}:S:1") elif dtype in (pl.Float32, pl.Float64): properties.append(f"{col}:R:1") elif dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64): properties.append(f"{col}:I:1") info_parts.append(f"Properties={':'.join(properties)}") if system.global_info: for key, value in system.global_info.items(): try: value_str = str(value) if not value_str.startswith("<") and not value_str.startswith("["): if "energy" in key: info_parts.append(f"{key}={value_str}") else: info_parts.append(f'{key}="{value_str}"') except Exception: continue if info_parts: f.write(" ".join(info_parts) + "\n") else: f.write("\n") write_columns = [] if "element" in df.columns: write_columns.append("element") write_columns.extend(["x", "y", "z"]) if is_extended: if all(c in df.columns for c in ["vx", "vy", "vz"]): write_columns.extend(["vx", "vy", "vz"]) if all(c in df.columns for c in ["fx", "fy", "fz"]): write_columns.extend(["fx", "fy", "fz"]) bec_cols = [f"bec_{i}" for i in range(9)] if all(c in df.columns for c in bec_cols): write_columns.extend(bec_cols) for col in df.columns: if col not in _STANDARD_COLS: write_columns.append(col) for row in df.select(write_columns).iter_rows(): line_data = [] for val in row: if isinstance(val, float): line_data.append(f"{val:.10f}") else: line_data.append(str(val)) f.write(" ".join(line_data) + "\n") def _write_multi_xyz(filename: str, systems: List[System], mode: str = "w", vacuum: float = 0.0) -> None: """Write a list of System frames to an XYZ file. ``vacuum`` is forwarded to :func:`_write_xyz_frame_to` per frame, so each frame's open axes get padded independently — useful when the trajectory is a mix of classical and extended frames. """ if mode not in ("w", "a"): raise ValueError(f"mode must be 'w' or 'a', got {mode!r}") with open(filename, mode) as f: for s in systems: _write_xyz_frame_to(f, s, vacuum=vacuum) # =========================================================================== # Unified multi-frame trajectory: XYZ or LAMMPS dump (read + write) # =========================================================================== def _infer_trajectory_format(filename: str) -> str: """Infer 'xyz' vs 'dump' from the filename, accepting `.gz` suffix.""" f = filename.lower() if f.endswith(".gz"): f = f[:-3] if f.endswith(".xyz"): return "xyz" if f.endswith(".dump") or f.endswith(".lammpstrj"): return "dump" raise ValueError( f"Cannot infer trajectory format from {filename!r}; " f"pass format='xyz' or format='dump' explicitly." ) def _progress(stream_name: str, current: int, total: int, every: int = 200) -> None: """In-place progress bar used by the verbose=True trajectory loaders. Refreshes a single line via carriage-return so the terminal doesn't fill up with hundreds of progress messages. The line stays visible after completion (we end with ``\\n`` on the final tick) so the user can still tell what the file was. Layout:: [xyz.serial] [#####.....] 500/1000 (50%)\\r """ if not (current == total or (current and current % every == 0)): return pct = 100.0 * current / max(total, 1) bar_w = 30 filled = int(round(pct / 100.0 * bar_w)) bar = "#" * filled + "." * (bar_w - filled) end = "\n" if current == total else "" # Trailing space pads over any leftover characters from a wider line. print( f" [{stream_name}] [{bar}] {current}/{total} ({pct:.0f}%) ", end=end + ("\r" if not end else ""), flush=True, ) def _read_multi_dump_serial(filename: str, verbose: bool = False) -> List[System]: """Split a LAMMPS dump file into frames at every ITEM: TIMESTEP and parse each one with `BuildSystem.parse_dump_frame`. Each per-frame parser already uses ``pl.read_csv`` on uniform-space blocks (see ``_parse_dump_frame_impl``), so the dump path is already vectorised — there's no separate `fast_mode` path that would add measurable speedup, by design. """ from mdapy.load_save import _open_file, BuildSystem with _open_file(filename, "r") as fp: lines = fp.readlines() ts_idx = [i for i, l in enumerate(lines) if l.strip().startswith("ITEM: TIMESTEP")] if not ts_idx: raise ValueError(f"{filename}: no ITEM: TIMESTEP header found") boundaries = ts_idx + [len(lines)] n_frames = len(ts_idx) systems: List[System] = [] for i in range(n_frames): frame = lines[boundaries[i] : boundaries[i + 1]] df, box, info = BuildSystem.parse_dump_frame( frame, source=f"{filename}[frame {i}]" ) systems.append(System(data=df, box=box, global_info=info)) if verbose: _progress("dump.serial", i + 1, n_frames) return systems # Back-compat alias. _read_multi_dump = _read_multi_dump_serial def _write_multi_dump(filename: str, systems: List[System], mode: str) -> None: from mdapy.load_save import SaveSystem if mode not in ("w", "a"): raise ValueError(f"mode must be 'w' or 'a', got {mode!r}") with open(filename, mode + "b") as fp: for sys_obj in systems: ts = sys_obj.global_info.get("timestep", 0) if sys_obj.global_info else 0 try: ts = int(ts) except (TypeError, ValueError): ts = 0 SaveSystem.write_dump_frame_to(fp, sys_obj.box, sys_obj.data, ts)
[docs] class Trajectory(_TrajectoryListBase): """Multi-frame trajectory container — supports XYZ and LAMMPS dump (read + write). Parameters ---------- filename : str, optional Path to load (``.xyz``, ``.dump``, ``.lammpstrj``, optionally ``.gz``). systems : list of mdapy.System, optional Pre-built list of frames. format : {'xyz', 'dump'}, optional Override file-format detection. Defaults to inferring from the filename extension. fast_mode : bool, default=False Use the vectorised XYZ reader (`_read_xyz_fast`). Amortises per-frame overhead (regex-parsing the comment line, building the column schema, constructing :class:`Box`) over a single bulk pass — typically 5–7× faster on a long, regular XYZ trajectory. Requires identical column schema across frames AND single-character whitespace separators in the per-atom block; raises a ``ValueError`` naming the offending frame otherwise. ``fast_mode=True`` is **not supported for LAMMPS dump** — the dump serial reader already vectorises each frame internally (via ``pl.read_csv``), so a separate "bulk" path adds complexity without measurable speedup. Pass ``fast_mode=True`` on a dump file and you get a ``ValueError`` saying so. verbose : bool, default=True Print a one-line ``frame i/N`` progress update every 200 frames during loading. Pass ``verbose=False`` to silence the output (useful for tests / scripts). Examples -------- >>> traj = mp.Trajectory("dump.lammpstrj") >>> for frame in traj: print(frame.N) >>> traj.save("subset.xyz", frames=[0, 2, 4]) >>> traj.append(other_system) >>> # boolean / integer-array indexing for filtering frames >>> hot = traj[traj.get_atoms_count() > 100] # bool mask >>> first_few = traj[[0, 5, 7]] # integer mask >>> # silence progress output >>> traj = mp.Trajectory("big.xyz", fast_mode=True, verbose=False) """ def __init__( self, filename: Optional[str] = None, systems: Optional[List[System]] = None, format: Optional[str] = None, fast_mode: bool = False, verbose: bool = True, ) -> None: self._systems: List[System] = [] self._filename = filename self._format = format self._fast_mode = fast_mode self._verbose = verbose if systems is not None: self._systems = list(systems) elif filename is not None: self._load() else: raise ValueError("Trajectory needs either filename= or systems=") def _load(self) -> None: fmt = self._format or _infer_trajectory_format(self._filename) if fmt == "xyz": # XYZTrajectory owns the XYZ parser pipeline; we instantiate it # purely as a parser (no list-API surface; XYZTrajectory now # inherits from Trajectory itself, see below). xt = XYZTrajectory.__new__(XYZTrajectory) xt._filename = self._filename xt._fast_mode = self._fast_mode xt._verbose = self._verbose if self._fast_mode: xt._systems = xt._read_xyz_fast(self._filename) else: xt._systems = xt._read_xyz_serial(self._filename, verbose=self._verbose) self._systems = xt._systems elif fmt == "dump": if self._fast_mode: raise ValueError( "fast_mode is not supported for LAMMPS dump format. " "The dump serial reader already vectorises each " "frame internally via pl.read_csv, so a separate " "bulk path would add complexity without measurable " "speedup. Pass fast_mode=False (the default)." ) self._systems = _read_multi_dump_serial( self._filename, verbose=self._verbose ) else: raise ValueError(f"Unsupported trajectory format: {fmt!r}")
[docs] def save( self, filename: str, frames: Optional[Union[int, List[int]]] = None, mode: str = "w", format: Optional[str] = None, vacuum: float = 0.0, ) -> None: """Write the trajectory to disk. Parameters ---------- filename : str Output path. Format is auto-detected from the suffix. frames : int or list of int, optional Subset of frame indices to save. Default: all frames. mode : {'w', 'a'}, default 'w' Open mode (write/truncate vs. append). format : {'xyz', 'dump'}, optional Override format detection. vacuum : float, default=0.0 **XYZ output only.** When ``> 0``, every non-periodic axis of every saved frame is padded by ``vacuum`` Å (atoms shifted by ``vacuum / 2`` so they sit centred in the new cell, boundary flipped to periodic). Useful for auto-boxing training-set frames that came in as classical XYZ (PBC = FFF) so downstream MD code sees a well-defined supercell. The in-memory trajectory is **not** mutated — padding is applied to a per-frame copy at write time. Ignored for the dump format (LAMMPS dumps already require an explicit box). """ if vacuum < 0: raise ValueError(f"vacuum must be >= 0, got {vacuum}.") if frames is None: sel = self._systems elif isinstance(frames, int): sel = [self._systems[frames]] else: sel = [self._systems[i] for i in frames] if mode not in ("w", "a"): raise ValueError(f"mode must be 'w' or 'a', got {mode!r}") fmt = format or _infer_trajectory_format(filename) if fmt == "xyz": _write_multi_xyz(filename, sel, mode, vacuum=vacuum) elif fmt == "dump": if vacuum > 0: import warnings warnings.warn( "vacuum>0 is ignored for LAMMPS dump output (dumps " "already require a fully defined box).", UserWarning, stacklevel=2, ) _write_multi_dump(filename, sel, mode) else: raise ValueError(f"Unsupported trajectory format: {fmt!r}")
# NOTE: list-like API (__len__, __getitem__, __setitem__, __iter__, # __repr__, append, extend, insert, pop, remove, get_atoms_count, # concatenate) is inherited from `_TrajectoryListBase`.
[docs] def unwrap(self) -> "Trajectory": """Return a new :class:`Trajectory` with continuous (unwrapped) particle positions. See :func:`mdapy.unwrap_trajectory` for the full algorithm description and edge cases.""" from mdapy.unwrap_trajectory import unwrap_trajectory return unwrap_trajectory(self)
if __name__ == "__main__": traj = Trajectory("/u/22/wuy33/unix/Desktop/GAP_CN/gap_cn_training_dataset.xyz") print(traj)