# Source code for mdapy.visualize

# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is from the mdapy project, released under the BSD 3-Clause License.

# k3d is the rendering backend of this module; fail early with an install hint.
try:
    import k3d
except ImportError as err:
    # Chain the original exception so the underlying import failure (missing
    # package, broken dependency, ...) remains visible in the traceback.
    raise ImportError(
        "One need install k3d: https://k3d-jupyter.org/user/index.html"
    ) from err

import numpy as np
import polars as pl
from k3d import Plot
from k3d.objects import Text2d
from mdapy.system import System
from mdapy.data import ele_radius, ele_dict, type_dict, struc_dict
from typing import Optional, Tuple


class View:
    """
    Visualize atomic systems using k3d.

    Parameters
    ----------
    system : System
        MDAPY System object containing atomic positions, box information and
        optional per-atom properties such as element, type, radius, etc.

    Attributes
    ----------
    system : System
        The input atomic system.
    plot : Plot
        The k3d canvas used for visualization.
    atoms : k3d.points
        Object storing atomic coordinates, colors and radii.
    box : k3d.lines
        Object representing the simulation box.
    label : Text2d or None
        The text label for colorbar. Only created when `colored_by()` is used.
    """

    # Structure index -> structure name, for every supported classifier column.
    _STRUCTURE_NAMES = {
        "ptm": {
            0: "Other",
            1: "FCC",
            2: "HCP",
            3: "BCC",
            4: "ICO",
            5: "Simple cubic",
            6: "Cubic diamond",
            7: "Hexagonal diamond",
            8: "Graphene",
        },
        "cna": {
            0: "Other",
            1: "FCC",
            2: "HCP",
            3: "BCC",
            4: "ICO",
        },
        "aja": {
            0: "Other",
            1: "FCC",
            2: "HCP",
            3: "BCC",
            4: "ICO",
        },
        "ids": {
            0: "Other",
            1: "Cubic diamond",
            2: "Cubic diamond (1st neighbor)",
            3: "Cubic diamond (2nd neighbor)",
            4: "Hexagonal diamond",
            5: "Hexagonal diamond (1st neighbor)",
            6: "Hexagonal diamond (2nd neighbor)",
        },
    }

    # Groups of the text legends that are hidden when a colorbar is shown.
    _LEGEND_GROUPS = ("ptm", "cna", "aja", "ids", "element", "type")

    # Below this value range a column is treated as constant in `colored_by`.
    _FLAT_RANGE = 1e-4

    def __init__(self, system: "System"):
        self.system = system
        self.label: Optional[Text2d] = None
        self.init_plot()

    def _box2lines(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Convert simulation box into line vertices and indices.

        Returns
        -------
        vertices : (8, 3) float32 ndarray
            Coordinates of the eight box corners. Corners 0-3 form the face
            spanned by the first two box vectors, corners 4-7 are the same
            face shifted by the third box vector.
        indices : (12, 2) uint32 ndarray
            Pairs of vertex indices defining the 12 box edges.
        """
        cell = self.system.box
        # assumes `box.box` holds the three box vectors as rows and
        # `box.origin` is a 3-vector -- TODO confirm against mdapy.box
        origin = np.asarray(cell.origin)
        edge_a, edge_b, edge_c = cell.box[0], cell.box[1], cell.box[2]

        bottom = np.array(
            [origin, origin + edge_a, origin + edge_a + edge_b, origin + edge_b]
        )
        vertices = np.concatenate((bottom, bottom + edge_c)).astype(np.float32)

        # Vertex indices are integers; store them with an integer dtype rather
        # than as floats.
        indices = np.array(
            [
                # bottom face
                [0, 1], [1, 2], [2, 3], [3, 0],
                # vertical edges
                [0, 4], [1, 5], [2, 6], [3, 7],
                # top face
                [4, 5], [5, 6], [6, 7], [7, 4],
            ],
            dtype=np.uint32,
        )
        return vertices, indices

    def _set_color_column(self, expr: "pl.Expr") -> None:
        """
        Store `expr` as the ``color`` column and push it to the atoms object.

        Parameters
        ----------
        expr : pl.Expr
            Expression evaluating to one color value per atom.

        Notes
        -----
        The k3d point object is only updated when it already exists, so this
        helper can be used before the plot objects are created.
        """
        self.system.update_data(
            self.system.data.with_columns(expr.rechunk().alias("color"))
        )
        if hasattr(self, "atoms"):
            self.atoms.colors = self.system.data["color"].to_numpy()

    def colored_by_type(self) -> None:
        """
        Color atoms using their type values.

        Notes
        -----
        The type is mapped cyclically into nine predefined colors.
        Colors are updated in `system.data["color"]` and applied to the
        k3d point object if already initialized.
        """
        self._set_color_column(
            ((pl.col("type") - 1) % 9 + 1).replace_strict(
                type_dict, return_dtype=pl.UInt32
            )
        )

    def colored_by_element(self) -> None:
        """
        Color atoms using their element name.

        Notes
        -----
        Element symbols are mapped to integer colors according to `ele_dict`.
        Colors are updated in `system.data["color"]` and applied to the
        k3d point object if already initialized.
        """
        self._set_color_column(
            pl.col("element").replace_strict(ele_dict, return_dtype=pl.UInt32)
        )

    def _init_color(self) -> None:
        """
        Initialize color column.

        Notes
        -----
        If `color` exists, cast to uint32.
        If not, determine color based on element or type.
        """
        columns = self.system.data.columns
        if "color" in columns:
            self.system.update_data(
                self.system.data.with_columns(pl.col("color").cast(pl.UInt32))
            )
        elif "element" in columns:
            self.colored_by_element()
        else:
            assert "type" in columns
            self.colored_by_type()

    def _init_radius(self) -> None:
        """
        Initialize radius column.

        Notes
        -----
        If `radius` exists, cast to float32.
        If element exists, use element-specific radius from `ele_radius`.
        Otherwise assign a default radius of 2.0.
        """
        columns = self.system.data.columns
        if "radius" in columns:
            radius = pl.col("radius").cast(pl.Float32)
        elif "element" in columns:
            radius = (
                pl.col("element")
                .replace_strict(ele_radius, return_dtype=pl.Float32)
                .rechunk()
                .alias("radius")
            )
        else:
            radius = pl.lit(2.0, pl.Float32).alias("radius")
        self.system.update_data(self.system.data.with_columns(radius))

    def init_plot(self) -> None:
        """
        Initialize the visualization canvas.

        Notes
        -----
        - Creates box lines, atomic points.
        - Initializes default color and radius.
        - Creates element/type legends using `text2d`. The type legend is
          laid out in rows of ten entries.
        """
        vertices, indices = self._box2lines()
        self.plot: Plot = k3d.plot(height=600)
        self._init_color()
        self._init_radius()

        self.box = k3d.lines(
            vertices,
            indices,
            color=0,
            indices_type="segment",
            width=1.5,
            shader="simple",
            group="box",
        )
        self.atoms = k3d.points(
            self.system.data.select("x", "y", "z").cast(pl.Float32).to_numpy(),
            colors=self.system.data["color"].to_numpy(),
            shader="3d",
            point_sizes=self.system.data["radius"].to_numpy(),
            group="atoms",
        )
        self.plot += self.box
        self.plot += self.atoms
        self.plot.grid_visible = False

        # Legend drawing. Every label receives its own position list so that
        # no two labels can end up sharing one mutable object.
        if "element" in self.system.data.columns:
            elements = self.system.data["element"].unique().sort()
            for i, symbol in enumerate(elements, start=1):
                self.plot += k3d.text2d(
                    symbol,
                    position=[i * 0.03, 0.0],
                    size=1.5,
                    is_html=True,
                    label_box=True,
                    color=ele_dict[symbol],
                    group="element",
                    name=f"{symbol}",
                )
        else:
            # Sorted for a deterministic legend order, like the element legend.
            type_ids = self.system.data["type"].unique().sort()
            for i, type_id in enumerate(type_ids):
                # Ten labels per row; start a new row after each ten.
                row, col = divmod(i, 10)
                self.plot += k3d.text2d(
                    f"Type {type_id:2}",
                    position=[col * 0.05, row * 0.07],
                    size=1.5,
                    is_html=True,
                    label_box=True,
                    color=type_dict[(type_id - 1) % 9 + 1],
                    group="type",
                    name=f"Type {type_id}",
                )

    def display(self) -> None:
        """
        Display the k3d plot in supported environments (e.g., Jupyter).
        """
        self.plot.display()

    def close(self) -> None:
        """
        Close the k3d plot and release the rendering canvas.
        """
        self.plot.close()

    def hide_object_by_group_name(self, name: str, remove: bool = False) -> None:
        """
        Hide or remove k3d objects by their group name.

        Parameters
        ----------
        name : str
            The object group to hide/remove.
        remove : bool, default False
            If True, remove the object entirely.
            If False, only hide it visually.
        """
        # Iterate over a copy: removing an object from the plot is expected to
        # shrink `plot.objects`, and mutating a list while iterating over it
        # skips the element that follows each removed one.
        for obj in list(self.plot.objects):
            if obj.group != name:
                continue
            if remove:
                self.plot -= obj
            else:
                obj.visible = False

    def delete_color_bar(self) -> None:
        """
        Remove existing colorbar from the plot.

        Notes
        -----
        Resets the color map and color range of the atoms object and hides
        the colorbar label if one was created.
        """
        self.atoms.color_map = []
        self.atoms.color_range = []
        if self.label is not None:
            self.label.visible = False

    def _colored_by_structure_type(self, method: str, show_label: bool = False) -> None:
        """
        Color atoms based on structural classification.

        Parameters
        ----------
        method : {'ptm', 'cna', 'aja', 'ids'}
            Column name storing structural type.
        show_label : bool, default False
            Whether to show per-structure text labels.

        Notes
        -----
        Updates atom colors using predefined structure -> color mapping.
        One label per structure is added, showing its atom count and
        percentage; labels left over from an earlier call for the same
        `method` are removed first.
        Clears colorbar since structure type does not require continuous scale.
        """
        assert method in self._STRUCTURE_NAMES
        assert method in self.system.data.columns

        names = self._STRUCTURE_NAMES[method]
        color_struc = {key: struc_dict[label] for key, label in names.items()}
        self._set_color_column(
            pl.col(method).replace_strict(color_struc, return_dtype=pl.UInt32)
        )

        # Rows of the grouped frame are (structure index, atom count).
        counts = dict(self.system.data.group_by(method).len().iter_rows())
        N = self.system.N

        # Drop labels from a previous call so they do not pile up.
        self.hide_object_by_group_name(method, remove=True)
        for key, label in names.items():
            n = counts.get(key, 0)
            percent = (n / N) * 100 if N else 0.0  # empty system: avoid 0/0
            self.plot += k3d.text2d(
                f"{label} {n} {percent:.1f}%",
                position=[0.0, key * 0.07],
                size=1.5,
                is_html=True,
                label_box=True,
                color=color_struc[key],
                group=method,
                name=label,
            )

        if not show_label:
            self.hide_object_by_group_name(method)
        self.delete_color_bar()

    def colored_by(
        self,
        name: str,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        cmap: str = "rainbow",
    ) -> None:
        """
        Color atoms based on a given scalar per-atom quantity.

        Parameters
        ----------
        name : str
            Column name in `system.data` used for coloring.
        vmin : float, optional
            Minimum value of the colormap. If None, the column minimum is used.
        vmax : float, optional
            Maximum value of the colormap. If None, the column maximum is used.
        cmap : str, default "rainbow"
            Matplotlib colormap name.

        Notes
        -----
        - If `name` is "element" or "type", discrete coloring is applied.
        - If `name` is a structure classifier ("ptm", "cna", "aja", "ids"),
          structure coloring is used.
        - Otherwise, continuous colormap coloring is used. Values outside
          ``[vmin, vmax]`` are clamped to the nearest bound.
        - If the value range is smaller than 1e-4, all atoms receive the
          middle color of the colormap.
        - Colorbar is updated accordingly and the text legends are hidden.
        """
        assert name in self.system.data.columns

        # Element/type/structure special handling
        if name == "element":
            self.colored_by_element()
            return
        if name == "type":
            self.colored_by_type()
            return
        if name in self._STRUCTURE_NAMES:
            self._colored_by_structure_type(name, True)
            return

        import matplotlib as mpl

        # Determine range: keep whichever bound the caller supplied and take
        # only the missing one(s) from the data.
        if vmin is not None and vmax is not None:
            assert vmin < vmax
        else:
            values = self.system.data[name]
            if vmin is None:
                vmin = float(values.min())
            if vmax is None:
                vmax = float(values.max())

        cmap_obj = mpl.colormaps[cmap]
        # 256 RGB triples with 0-255 integer channels (alpha dropped).
        colors_rgb = np.array(cmap_obj(range(256))[:, :-1] * 255, dtype=np.uint32)

        delta = vmax - vmin
        # Single flag so that coloring and color range always agree.
        is_flat = delta < self._FLAT_RANGE

        if is_flat:
            # Assign middle color when no range exists
            r, g, b = colors_rgb[len(colors_rgb) // 2]
            colors = np.full(self.system.N, (r << 16) + (g << 8) + b, np.uint32)
        else:
            n_colors = colors_rgb.shape[0]
            factor = (n_colors - 1) / delta
            # Clamp to [vmin, vmax], then scale to a color-table index.
            index = (
                self.system.data.select(
                    pl.when(pl.col(name) > vmax)
                    .then(pl.lit(vmax))
                    .when(pl.col(name) < vmin)
                    .then(pl.lit(vmin))
                    .otherwise(pl.col(name))
                    .alias(name)
                )
                .select(((pl.col(name) - vmin) * factor).cast(pl.UInt32))[name]
                .to_numpy()
            )
            r, g, b = colors_rgb[index].T
            # Pack the channels into one 0xRRGGBB integer per atom.
            colors = (r << 16) + (g << 8) + b

        self.atoms.colors = colors

        # Color map: flattened rows of (position in [0, 1], r, g, b).
        samples = np.linspace(0, 1, 256)
        c_cmap = np.c_[samples, cmap_obj(samples)[:, :-1]].flatten().astype(np.float32)
        self.atoms.color_map = c_cmap

        # Color range
        if is_flat:
            self.atoms.color_range = [vmin - 3, vmin + 3]
        else:
            self.atoms.color_range = [vmin, vmax]

        # Colorbar label
        if self.label is None:
            self.label = k3d.text2d(
                name,
                position=(0.01, 0.5),
                size=2,
                is_html=True,
                label_box=False,
                color=0,
                group="colorbar",
            )
            self.plot += self.label
        else:
            self.label.text = name
            self.label.visible = True

        # Hide structure/type/element legends
        for obj in self.plot.objects:
            if obj.group in self._LEGEND_GROUPS:
                obj.visible = False

        self.system.update_data(self.system.data.with_columns(color=colors))