Source code for mdapy.phonon

# Copyright (c) 2022, mushroomfire in Beijing Institute of Technology
# This file is from the mdapy project, released under the BSD 3-Clause License.

import matplotlib.pyplot as plt
import numpy as np

try:
    from .plotset import set_figure
except Exception:
    from plotset import set_figure


class Phonon:
    r"""Plot the phonon dispersion from band data generated by `phonopy <https://phonopy.github.io/phonopy/>`_.

    Args:
        filename (str): the filename of band data, such as band.dat, which can be generated by:
          phonopy-bandplot --gnuplot band.yaml>band.dat

    Attributes:
        filename (str): the path given at construction.
        data (dict[str, np.ndarray]): one entry per non-empty segment of the
            band file, keyed by "0", "1", ...; column 0 is used as the x
            coordinate and column 1 as the frequency when plotting.
        kpoints (np.ndarray): the values read from the second line of the
            band file (first token skipped); used as x tick positions.
    """

    # Factor applied to the stored frequencies when units == "1/cm".
    # NOTE(review): 33.4 is the historical rounded value; 1 THz is about
    # 33.356 cm^-1. Kept unchanged so existing plots do not shift.
    _THZ_TO_INV_CM = 33.4

    def __init__(self, filename) -> None:
        self.filename = filename
        self.data, self.kpoints = self._get_band_data()

    def _get_band_data(self):
        """Parse the band file into per-segment arrays and the k-point positions.

        The first two lines are treated as a header; the second one carries
        the k-point positions after its first token. The remaining lines are
        split into segments at blank lines.

        Returns:
            tuple: (dict of segment arrays keyed by "0", "1", ..., kpoints array).
        """
        with open(self.filename) as op:
            data = op.readlines()
        kpoints = np.array(data[1].split()[1:], float)

        # Segment boundaries: the header line, every blank line after it, and
        # the end of the file so that a final segment which is not followed by
        # a blank line is not lost. Whitespace-only lines count as blank.
        sepa = [1]
        sepa.extend(
            i for i, line in enumerate(data) if i > 1 and not line.strip()
        )
        sepa.append(len(data))

        res = {}
        num = 0
        for start, stop in zip(sepa[:-1], sepa[1:]):
            rows = [line.split() for line in data[start + 1 : stop]]
            if rows:  # consecutive blank lines produce empty segments
                res[f"{num}"] = np.array(rows, float)
                num += 1
        return res, kpoints

    def _rgb2hex(self, rgb):
        """Convert an iterable of 0-255 channel values to a "#RRGGBB" string.

        Raises:
            ValueError: if a channel is outside the 0..255 range.
        """
        parts = []
        for channel in rgb:
            value = int(channel)
            if not 0 <= value <= 255:
                raise ValueError(
                    f"rgb channel values must be in [0, 255], got {channel!r}."
                )
            parts.append(f"{value:02X}")
        return "#" + "".join(parts)

    def _frequency_in_units(self, freq, units):
        """Return the frequencies expressed in *units* without touching *freq*.

        Any value of *units* other than "1/cm" returns *freq* unchanged.
        """
        if units == "1/cm":
            # Build a new array: *freq* is a view into self.data, and scaling
            # it in place would compound on every plot_dispersion() call.
            return freq * self._THZ_TO_INV_CM
        return freq

    @staticmethod
    def get_supercell(
        rx, ry, rz, inputfile="cp2k.inp", tolerance=1e-3, backend="cp2k"
    ):
        """Build the phonopy command line for a ``rx x ry x rz`` supercell.

        The command is only returned as a string, it is not executed.

        Args:
            rx (int): supercell repetition passed as first ``--dim`` value.
            ry (int): supercell repetition passed as second ``--dim`` value.
            rz (int): supercell repetition passed as third ``--dim`` value.
            inputfile (str, optional): file passed with ``-c``; only used by the
                "cp2k" backend. Defaults to "cp2k.inp".
            tolerance (float, optional): value passed to ``--tolerance``. Defaults to 1e-3.
            backend (str, optional): selected in ['cp2k', 'vasp']. Defaults to "cp2k".

        Returns:
            str: the phonopy command.

        Raises:
            ValueError: if backend is not one of the supported names.
        """
        if backend == "cp2k":
            return f'phonopy --cp2k -c {inputfile} -d --dim="{rx} {ry} {rz}" --tolerance {tolerance}'
        if backend == "vasp":
            return f'phonopy -d --dim="{rx} {ry} {rz}" --tolerance {tolerance}'
        raise ValueError(
            f"Unsupported backend {backend!r}, only 'cp2k' and 'vasp' are supported."
        )

    def plot_dispersion(
        self,
        units="THz",
        kpoints_label=None,
        yticks=None,
        color=None,
        merge_kpoints=None,
    ):
        r"""This function can plot the phonon dispersion.

        Args:
            units (str, optional): units of frequency, selected in ['THz', '1/cm']. Defaults to "THz".
            kpoints_label (list[str], optional): kpoints label, such as ["$\Gamma$", "K", "M", "$\Gamma$"] for graphene. It is only applied when its length equals the number of kpoints. Defaults to None.
            yticks (list[float], optional): y axis ticks, such as [0, 10, 20]. Defaults to None.
            color (str | rgb tuple, optional): line color, can be a str, such as 'r', '#729CBD', or a rgb tuple, such as [125, 125, 125]. Defaults to None.
            merge_kpoints (list, optional): indices [L, R] (L < R) of two equivalent kpoints to merge, such as [2, 3]; segments lying between them are not drawn and everything after R is shifted left. Defaults to None.

        Returns:
            tuple: (fig, ax) matplotlib figure and axis class.
        """
        fig, ax = set_figure(
            figsize=(10, 7.5), bottom=0.08, left=0.16, use_pltset=True, figdpi=200
        )

        if color is None:
            color = "b"
        elif not isinstance(color, str):
            assert (
                len(color) == 3
            ), "Only support str or a three-elements rgb turple, such as [125, 125, 125]."
            color = self._rgb2hex(color)

        if merge_kpoints is None:
            for band in self.data.values():
                x = band[:, 0]
                y = self._frequency_in_units(band[:, 1], units)
                ax.plot(x, y, lw=1.2, c=color)
            # Horizontal reference line at zero frequency.
            ax.plot(
                [self.kpoints[0], self.kpoints[-1]],
                [0, 0],
                "--",
                c="grey",
                lw=1.0,
                alpha=0.5,
            )
            ax.set_xlim(self.kpoints[0], self.kpoints[-1])
            ax.set_xticks(self.kpoints)
        else:
            assert len(merge_kpoints) == 2
            assert min(merge_kpoints) >= 0
            assert max(merge_kpoints) <= len(self.kpoints) - 1
            L, R = merge_kpoints
            assert L < R
            # Distance removed from the x axis by collapsing R onto L.
            move = self.kpoints[R] - self.kpoints[L]
            for band in self.data.values():
                x = band[:, 0]
                y = self._frequency_in_units(band[:, 1], units)
                if (
                    x.min() >= self.kpoints[L] - 0.01
                    and x.max() <= self.kpoints[R] + 0.01
                ):
                    # Segment lies inside the merged gap: not drawn.
                    continue
                if x.min() >= self.kpoints[R] - 0.01:
                    # Segment is to the right of the gap: shift it left.
                    ax.plot(x - move, y, lw=1.2, c=color)
                else:
                    ax.plot(x, y, lw=1.2, c=color)
            ax.plot(
                [self.kpoints[0], self.kpoints[-1] - move],
                [0, 0],
                "--",
                c="grey",
                lw=1.0,
                alpha=0.5,
            )
            ax.set_xlim(self.kpoints[0], self.kpoints[-1] - move)
            if R == len(self.kpoints) - 1:
                ax.set_xticks(self.kpoints[: L + 1])
            else:
                ax.set_xticks(
                    np.hstack(
                        [self.kpoints[: L + 1], self.kpoints[R + 1 :] - move]
                    )
                )

        ax.set_xticks([], minor=True)
        ax.set_yticks([], minor=True)
        if yticks is not None:
            ax.set_yticks(yticks)

        if kpoints_label is not None and len(kpoints_label) == len(self.kpoints):
            if merge_kpoints is None:
                ax.set_xticklabels(kpoints_label)
            else:
                # Plain list concatenation keeps the labels as str and avoids
                # stacking an empty (float) array with strings when L == 0.
                labels = list(kpoints_label[:L])
                labels.append(kpoints_label[L] + "$|$" + kpoints_label[R])
                if R != len(self.kpoints) - 1:
                    labels.extend(kpoints_label[R + 1 :])
                ax.set_xticklabels(labels)

        if units == "1/cm":
            ax.set_ylabel("Frequency ($cm^{-1}$)")
        else:
            ax.set_ylabel("Frequency (THz)")

        # Vertical guide lines at every tick; the y limits are captured first
        # and restored afterwards so the guides do not rescale the axis.
        ylo, yhi = ax.get_ylim()
        for tick in ax.get_xticks():
            ax.plot(
                [tick, tick],
                [ylo, yhi],
                "--",
                lw=0.8,
                c="grey",
                alpha=0.5,
            )
        ax.set_ylim(ylo, yhi)

        plt.show()
        return fig, ax
if __name__ == "__main__":
    # Demo: plot the aluminum dispersion from a local band file, merging the
    # k-points with indices 2 and 3 and showing frequencies in 1/cm.
    pho = Phonon(r"D:\Study\Gra-Al\init_data\cp2k_test\band_data\aluminum\band.dat")
    # Example k-point labels:
    #   ["$\Gamma$", "X", "U", "K", "$\Gamma$", "L"]  Al
    #   ["$\Gamma$", "K", "M", "$\Gamma$"]  graphene
    #   ["$\Gamma$", "T", "$H_2$", "L", "$\Gamma$", "$S_0$", "F", "$\Gamma$"]  alc
    # Example rgb color: (24, 170, 201)
    fig, ax = pho.plot_dispersion(
        # Raw strings: "\G" is not a valid escape sequence in a normal string.
        kpoints_label=[r"$\Gamma$", "X", "U", "K", r"$\Gamma$", "L"],
        color="#729CBD",
        units="1/cm",
        merge_kpoints=[2, 3],
        # yticks=range(0, 2000, 400),
    )
    # fig.savefig(
    #     r"D:\Study\Gra-Al\init_data\cp2k_test\band_data\aluminum\band.png",
    #     dpi=300,
    #     bbox_inches="tight",
    #     transparent=True,
    # )
    # print(Phonon.get_supercell(1, 1, 1, "cp2k.inp"))