# Copyright (c) 2022, mushroomfire in Beijing Institute of Technology
# This file is from the mdapy project, released under the BSD 3-Clause License.
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np

try:
    from .plotset import set_figure
except Exception:
    from plotset import set_figure
class Phonon:
    r"""Plot the phonon dispersion from band data generated by `phonopy <https://phonopy.github.io/phonopy/>`_.

    Args:
        filename (str): the filename of band data, such as band.dat, which can be generated by:
            phonopy-bandplot --gnuplot band.yaml>band.dat

    Attributes:
        filename (str): the path passed to the constructor.
        data (dict[str, np.ndarray]): one array per blank-line-separated segment of the file,
            keyed "0", "1", ...; column 0 is the x coordinate along the band path and
            column 1 is the frequency (THz).
        kpoints (np.ndarray): x positions of the high-symmetry points, read from the
            second line of the file.
    """

    # 1 THz expressed in cm^-1: 1e12 Hz / (2.99792458e10 cm/s).
    THZ_TO_INV_CM = 33.35641

    def __init__(self, filename) -> None:
        self.filename = filename
        self.data, self.kpoints = self._get_band_data()

    def _get_band_data(self):
        """Parse the band file into per-segment arrays.

        The second line (``# k1 k2 ...``) holds the high-symmetry point positions. Every
        later non-comment line is a row of numbers. Blank (or whitespace-only) lines
        separate segments.

        Returns:
            tuple: (dict of segment arrays keyed "0", "1", ..., kpoints array).
        """
        with open(self.filename, encoding="utf-8") as fh:
            lines = fh.readlines()
        # Skip the leading "#" token on the k-point header line.
        kpoints = np.array(lines[1].split()[1:], float)

        segments = {}
        rows = []
        # The sentinel blank line flushes the last segment even when the file
        # does not end with a blank line.
        for line in lines[2:] + ["\n"]:
            fields = line.split()
            if not fields:
                if rows:
                    segments[str(len(segments))] = np.array(rows, float)
                    rows = []
            elif not fields[0].startswith("#"):
                rows.append(fields)
        return segments, kpoints

    def _rgb2hex(self, rgb):
        """Convert an (r, g, b) triple of 0-255 values to an uppercase ``#RRGGBB`` string.

        Raises:
            ValueError: if any channel is outside 0-255.
        """
        channels = [int(round(c)) for c in rgb]
        if any(not 0 <= c <= 255 for c in channels):
            raise ValueError(f"RGB values must be within 0-255, got {tuple(rgb)}.")
        return "#" + "".join(f"{c:02X}" for c in channels)

    def _resolve_color(self, color):
        """Normalise the ``color`` argument of plot_dispersion to a matplotlib color."""
        if color is None:
            return "b"
        if isinstance(color, str):
            return color
        assert (
            len(color) == 3
        ), "Only support str or a three-elements rgb turple, such as [125, 125, 125]."
        return self._rgb2hex(color)

    def _check_merge(self, merge_kpoints):
        """Validate ``merge_kpoints`` and return it as an (L, R) index pair with L < R."""
        assert len(merge_kpoints) == 2
        assert min(merge_kpoints) >= 0
        assert max(merge_kpoints) <= len(self.kpoints) - 1
        left, right = merge_kpoints
        assert left < right
        return left, right

    @staticmethod
    def get_supercell(rx, ry, rz, inputfile="cp2k.inp", tolerance=1e-3, backend="cp2k"):
        """Build the phonopy command line that generates displaced supercells.

        Args:
            rx, ry, rz (int): supercell multiples passed to phonopy's ``--dim``.
            inputfile (str, optional): CP2K input file; only used by the "cp2k" backend.
                Defaults to "cp2k.inp".
            tolerance (float, optional): value passed to ``--tolerance``. Defaults to 1e-3.
            backend (str, optional): "cp2k" or "vasp". Defaults to "cp2k".

        Returns:
            str: the shell command.

        Raises:
            ValueError: if ``backend`` is not supported.
        """
        dim = f'--dim="{rx} {ry} {rz}" --tolerance {tolerance}'
        if backend == "cp2k":
            return f"phonopy --cp2k -c {inputfile} -d {dim}"
        if backend == "vasp":
            return f"phonopy -d {dim}"
        raise ValueError(f"Unsupported backend {backend!r}; expected 'cp2k' or 'vasp'.")

    def plot_dispersion(
        self,
        units="THz",
        kpoints_label=None,
        yticks=None,
        color=None,
        merge_kpoints=None,
    ):
        r"""Plot the phonon dispersion.

        Args:
            units (str, optional): units of frequency, selected in ['THz', '1/cm']. Any value
                other than '1/cm' is treated as 'THz'. Defaults to "THz".
            kpoints_label (list[str], optional): kpoints label, such as
                ["$\Gamma$", "K", "M", "$\Gamma$"] for graphene. It must have one entry per
                high-symmetry point; otherwise it is ignored with a warning. Defaults to None.
            yticks (list[float], optional): y axis ticks, such as [0, 10, 20]. Defaults to None.
            color (str | rgb tuple, optional): line color. Either a matplotlib color string such
                as 'r' or '#729CBD', or an rgb tuple of 0-255 values such as [125, 125, 125].
                Defaults to None (blue).
            merge_kpoints (list, optional): a pair of high-symmetry point indices [L, R] with
                L < R to merge, such as [2, 3]. Segments between them are dropped, segments
                after R are shifted left to close the gap, and the tick at L is labelled
                "L|R". Defaults to None.

        Returns:
            tuple: (fig, ax) matplotlib figure and axis class.
        """
        fig, ax = set_figure(
            figsize=(10, 7.5), bottom=0.08, left=0.16, use_pltset=True, figdpi=200
        )
        color = self._resolve_color(color)
        factor = self.THZ_TO_INV_CM if units == "1/cm" else 1.0
        kpts = self.kpoints
        tol = 0.01  # slack when comparing segment x-ranges with k-point positions

        if merge_kpoints is None:
            move = 0.0
            for seg in self.data.values():
                # Multiplying makes a new array. seg[:, 1] is a view into self.data,
                # so scaling it in place would corrupt the stored data.
                ax.plot(seg[:, 0], seg[:, 1] * factor, lw=1.2, c=color)
            xticks = kpts
        else:
            left, right = self._check_merge(merge_kpoints)
            move = kpts[right] - kpts[left]
            for seg in self.data.values():
                x, y = seg[:, 0], seg[:, 1] * factor
                if x.min() >= kpts[left] - tol and x.max() <= kpts[right] + tol:
                    continue  # segment lies inside the merged span: drop it
                if x.min() >= kpts[right] - tol:
                    x = x - move  # shift segments after the merge left to close the gap
                ax.plot(x, y, lw=1.2, c=color)
            xticks = np.hstack([kpts[: left + 1], kpts[right + 1 :] - move])

        x_end = kpts[-1] - move
        # Dashed zero-frequency baseline across the whole (possibly shortened) path.
        ax.plot([kpts[0], x_end], [0, 0], "--", c="grey", lw=1.0, alpha=0.5)
        ax.set_xlim(kpts[0], x_end)
        ax.set_xticks(xticks)
        ax.set_xticks([], minor=True)
        ax.set_yticks([], minor=True)
        if yticks is not None:
            ax.set_yticks(yticks)
        if kpoints_label is not None:
            if len(kpoints_label) != len(kpts):
                warnings.warn(
                    f"kpoints_label has {len(kpoints_label)} entries but there are "
                    f"{len(kpts)} high-symmetry points; labels are ignored."
                )
            elif merge_kpoints is None:
                ax.set_xticklabels(kpoints_label)
            else:
                # Plain list concatenation avoids NumPy str/float dtype promotion
                # when the slice after R is empty.
                ax.set_xticklabels(
                    list(kpoints_label[:left])
                    + [kpoints_label[left] + "$|$" + kpoints_label[right]]
                    + list(kpoints_label[right + 1 :])
                )
        if units == "1/cm":
            ax.set_ylabel("Frequency ($cm^{-1}$)")
        else:
            ax.set_ylabel("Frequency (THz)")
        # Vertical dashed guide lines at every high-symmetry tick, keeping the y-range fixed.
        ylo, yhi = ax.get_ylim()
        for tick in ax.get_xticks():
            ax.plot([tick, tick], [ylo, yhi], "--", lw=0.8, c="grey", alpha=0.5)
        ax.set_ylim(ylo, yhi)
        plt.show()
        return fig, ax
if __name__ == "__main__":
    # Demo: plot the band data given on the command line, or the author's sample file.
    # Usage: python <this file> [path/to/band.dat]
    band_file = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"D:\Study\Gra-Al\init_data\cp2k_test\band_data\aluminum\band.dat"
    )
    pho = Phonon(band_file)
    # Example kpoints_label sets (raw strings avoid the invalid "\G" escape warning):
    #   Al:       [r"$\Gamma$", "X", "U", "K", r"$\Gamma$", "L"]
    #   graphene: [r"$\Gamma$", "K", "M", r"$\Gamma$"]
    #   alc:      [r"$\Gamma$", "T", "$H_2$", "L", r"$\Gamma$", "$S_0$", "F", r"$\Gamma$"]
    # Example rgb color: (24, 170, 201)
    fig, ax = pho.plot_dispersion(
        kpoints_label=[r"$\Gamma$", "X", "U", "K", r"$\Gamma$", "L"],
        color="#729CBD",
        units="1/cm",
        merge_kpoints=[2, 3],
        # yticks=range(0, 2000, 400),
    )
    # fig.savefig(
    #     r"D:\Study\Gra-Al\init_data\cp2k_test\band_data\aluminum\band.png",
    #     dpi=300,
    #     bbox_inches="tight",
    #     transparent=True,
    # )
    # print(Phonon.get_supercell(1, 1, 1, "cp2k.inp"))