Source code for mdapy.visualize

# Copyright (c) 2022-2024, mushroomfire in Beijing Institute of Technology
# This file is from the mdapy project, released under the BSD 3-Clause License.

import k3d
import numpy as np
import polars as pl
import taichi as ti
import matplotlib as mpl

try:
    from .box import init_box
    from .tool_function import ele_radius, ele_dict, struc_dict, type_dict
except Exception:
    from box import init_box
    from tool_function import ele_radius, ele_dict, struc_dict, type_dict


[docs] @ti.kernel def value2color( colors_rgb: ti.types.ndarray(element_dim=1), value: ti.types.ndarray(), vmin: float, vmax: float, colors: ti.types.ndarray(), ): delta = vmax - vmin N = colors_rgb.shape[0] fac = (N - 1) / delta for i in range(value.shape[0]): val = ti.float64(value[i]) if val > vmax: val = vmax elif val < vmin: val = vmin r, g, b = colors_rgb[ti.floor((val - vmin) * fac, int)] colors[i] = (r << 16) + (g << 8) + b
[docs] class Visualize: def __init__(self, data, box) -> None: assert isinstance(data, pl.DataFrame) self.data = data self.label = None self.init_plot(*self.box2lines(box))
[docs] def box2lines(self, box): new_box, _, _ = init_box(box) vertices = np.zeros((8, 3), dtype=np.float32) origin = new_box[-1] AB = new_box[0] AD = new_box[1] AA1 = new_box[2] vertices[0] = origin vertices[1] = origin + AB vertices[2] = origin + AB + AD vertices[3] = origin + AD vertices[4] = vertices[0] + AA1 vertices[5] = vertices[1] + AA1 vertices[6] = vertices[2] + AA1 vertices[7] = vertices[3] + AA1 indices = np.zeros((12, 2), dtype=np.float32) indices[0] = [0, 1] indices[1] = [1, 2] indices[2] = [2, 3] indices[3] = [3, 0] indices[4] = [0, 4] indices[5] = [1, 5] indices[6] = [2, 6] indices[7] = [3, 7] indices[8] = [4, 5] indices[9] = [5, 6] indices[10] = [6, 7] indices[11] = [7, 4] return vertices, indices
[docs] def init_color(self): if "color" not in self.data.columns: if "type_name" in self.data.columns: self.atom_colored_by_atom_type_name() else: self.atom_colored_by_atom_type()
[docs] def init_radius(self): if "radius" not in self.data.columns: if "type_name" in self.data.columns: self.data = self.data.with_columns( pl.col("type_name") .replace(ele_radius, default=2.0, return_dtype=pl.Float32) .alias("radius") ) else: self.data = self.data.with_columns( pl.lit(2.0).cast(pl.Float32).alias("radius") )
[docs] def init_plot(self, vertices, indices): self.plot = k3d.plot(height=600) self.init_color() self.init_radius() self.box = k3d.lines( vertices, indices, color=0, indices_type="segment", width=1.5, shader="simple", group="box", ) self.atoms = k3d.points( self.data.select("x", "y", "z").to_numpy().astype(np.float32), colors=np.array(self.data["color"].to_numpy(), np.uint32), shader="3d", point_sizes=self.data["radius"].to_numpy().astype(np.float32), group="atoms", ) self.plot += self.box self.plot += self.atoms self.plot.grid_visible = False if "type_name" in self.data.columns: res = self.data.unique("type_name").sort("type").select("type_name", "type") res = {res[i, 0]: res[i, 1] for i in range(res.shape[0])} pos = [0.0, 0.0] for i, j in enumerate(self.data["type_name"].unique()): pos[0] = i * 0.03 self.plot += k3d.text2d( j, position=pos, size=1.5, is_html=True, label_box=True, color=ele_dict[j], group="type_name", name=f"{j} (Type {res[j]})", ) else: pos = [0.0, 0.0] for i, j in enumerate(self.data["type"].unique()): pos[0] = i * 0.05 if pos[0] > 0.45: pos[0] = (i - 10) * 0.05 pos[1] = 0.07 self.plot += k3d.text2d( f"Type {j:2}", position=pos, size=1.5, is_html=True, label_box=True, color=type_dict[(j - 1) % 9 + 1], group="type", name=f"Type {j}", )
[docs] def hide_object_by_group_name(self, name, remove=False): for i in self.plot.objects: if i.group == name: if remove: self.plot -= i else: i.visible = False
[docs] def show_object_by_group_name(self, name): found_name = False for i in self.plot.objects: if i.group == name: i.visible = True found_name = True if not found_name: print(f"Did not find {name}.")
[docs] def display(self): self.plot.display() self.hide_object_by_group_name("type_name") self.hide_object_by_group_name("type")
[docs] def close(self): self.plot.close()
[docs] def delete_color_bar(self): self.atoms.color_map = [] self.atoms.color_range = [] if self.label is not None: self.label.visible = False
# if self.label is not None: # self.plot -= self.label # self.label = None
[docs] def atom_colored_by_atom_type(self): self.data = self.data.with_columns( ((pl.col("type") - 1) % 9 + 1) .replace(type_dict, return_dtype=pl.UInt32) .alias("color") ) if hasattr(self, "atoms"): self.atoms.colors = np.array(self.data["color"].to_numpy(), np.uint32) self.delete_color_bar()
[docs] def atom_colored_by_atom_type_name(self): n = 1 for i in self.data["type_name"].unique(): if i not in ele_dict.keys(): ele_dict[i] = type_dict[n % 9] n += 1 self.data = self.data.with_columns( pl.col("type_name").replace(ele_dict, return_dtype=pl.UInt32).alias("color") ) if hasattr(self, "atoms"): self.atoms.colors = np.array(self.data["color"].to_numpy(), np.uint32) for i in ["ptm", "cna", "aja", "ids"]: self.hide_object_by_group_name(i) self.delete_color_bar()
[docs] def atom_colored_by_structure_type(self, method, show_label=False): avia_method = ["ptm", "cna", "aja", "ids"] assert method in avia_method assert method in self.data.columns N = self.data.shape[0] for i in avia_method: if method != i: self.hide_object_by_group_name(i) if method == "ptm": struc = { 0: "Other", 1: "FCC", 2: "HCP", 3: "BCC", 4: "ICO", 5: "Simple cubic", 6: "Cubic diamond", 7: "Hexagonal diamond", 8: "Graphene", } elif method == "cna" or method == "aja": struc = { 0: "Other", 1: "FCC", 2: "HCP", 3: "BCC", 4: "ICO", } elif method == "ids": struc = { 0: "Other", 1: "Cubic diamond", 2: "Cubic diamond (1st neighbor)", 3: "Cubic diamond (2nd neighbor)", 4: "Hexagonal diamond", 5: "Hexagonal diamond (1st neighbor)", 6: "Hexagonal diamond (2nd neighbor)", } color_struc = {i: struc_dict[struc[i]] for i in struc.keys()} self.data = self.data.with_columns( pl.col(method).replace(color_struc, return_dtype=pl.UInt32).alias("color") ) number = self.data.group_by(method).len() number_dict = {number[i, 0]: number[i, 1] for i in range(number.shape[0])} pos = [0.0, 0.0] for i in struc.keys(): pos[1] = i * 0.07 if i in number_dict.keys(): n = number_dict[i] else: n = 0 self.plot += k3d.text2d( f"{struc[i]} {n} {(n/N)*100:.1f}%", position=pos, size=1.5, is_html=True, label_box=True, color=color_struc[i], group=method, name=struc[i], ) self.atoms.colors = np.array(self.data["color"].to_numpy(), np.uint32) if not show_label: self.hide_object_by_group_name(method) self.delete_color_bar()
[docs] def atom_colored_by(self, values, vmin=None, vmax=None, cmap="rainbow"): value_name = values if isinstance(values, str): assert values in self.data.columns if values != "type_name": assert self.data[values].dtype in pl.NUMERIC_DTYPES if values == "type": self.atom_colored_by_atom_type() return elif values == "type_name": self.atom_colored_by_atom_type_name() return elif values in ["ptm", "cna", "aja", "ids"]: self.atom_colored_by_structure_type(values) return values = self.data[values].to_numpy() else: assert values.shape[0] == self.data.shape[0] if vmin is not None and vmax is not None: assert vmin < vmax else: vmin, vmax = float(values.min()), float(values.max()) cmap = mpl.colormaps[cmap] colors_rgb = np.array(cmap(range(256))[:, :-1] * 255, dtype=int) colors = np.zeros(values.shape[0], dtype=int) if vmax - vmin > 1e-4: value2color(colors_rgb, values, vmin, vmax, colors) else: r, g, b = colors_rgb[int(len(colors_rgb) / 2)] colors += (r << 16) + (g << 8) + b colors = colors.astype(np.uint32) self.atoms.colors = colors c_cmap = ( np.c_[np.linspace(0, 1, 256), cmap(np.linspace(0, 1, 256))[:, :-1]] .flatten() .astype(np.float32) ) self.atoms.color_map = c_cmap if vmax - vmin > 1e-4: self.atoms.color_range = [vmin, vmax] else: self.atoms.color_range = [vmin - 5, vmin + 5] if isinstance(value_name, str): if self.label is None: self.label = k3d.text2d( value_name, position=(0.01, 0.5), size=2, is_html=True, label_box=False, color=0, group="colorbar", ) self.plot += self.label else: self.label.text = value_name self.label.visible = True self.data = self.data.with_columns(pl.lit(colors).alias("colors"))