# Copyright (c) 2022, mushroomfire in Beijing Institute of Technology
# This file is from the mdapy project, released under the BSD 3-Clause License.
import matplotlib.pyplot as plt
from cycler import cycler
[docs]
# Named color palettes accepted by ``pltset(color_cycler=<name>)``.
# Stored as immutable tuples; a fresh list is built on every use.
_PLTSET_PALETTES = {
    "std-colors": (
        "#0C5DA5",
        "#00B945",
        "#FF9500",
        "#FF2C00",
        "#845B97",
        "#474747",
        "#9e9e9e",
    ),
    "bright": (
        "#4477AA",
        "#EE6677",
        "#228833",
        "#CCBB44",
        "#66CCEE",
        "#AA3377",
        "#BBBBBB",
    ),
    "high-contrast": ("#004488", "#DDAA33", "#BB5566"),
    "high-vis": ("#0d49fb", "#e6091c", "#26eb47", "#8936df", "#fec32d", "#25d7fd"),
    "light": (
        "#77AADD",
        "#EE8866",
        "#EEDD88",
        "#FFAABB",
        "#99DDFF",
        "#44BB99",
        "#BBCC33",
        "#AAAA00",
        "#DDDDDD",
    ),
    "muted": (
        "#CC6677",
        "#332288",
        "#DDCC77",
        "#117733",
        "#88CCEE",
        "#882255",
        "#44AA99",
        "#999933",
        "#AA4499",
        "#DDDDDD",
    ),
    "retro": ("#4165c0", "#e770a2", "#5ac3be", "#696969", "#f79a1e", "#ba7dcd"),
    "vibrant": (
        "#EE7733",
        "#0077BB",
        "#33BBEE",
        "#EE3377",
        "#CC3311",
        "#009988",
        "#BBBBBB",
    ),
}

# Palettes that also cycle the line style together with the color.
_PLTSET_LINESTYLES = {
    "high-vis": ("-", "--", "-.", ":", "-", "--"),
}


def _pltset_palette(color_cycler):
    """Validate *color_cycler* and return ``(colors, linestyles)`` as lists (``linestyles`` may be None)."""
    if color_cycler is None:
        # No explicit choice: fall back to the 'std-colors' palette.
        color_cycler = "std-colors"
    if isinstance(color_cycler, str):
        if color_cycler not in _PLTSET_PALETTES:
            raise ValueError(
                "color_cycler must belong to {}, or a list, or a tuple! Got {!r}.".format(
                    sorted(_PLTSET_PALETTES), color_cycler
                )
            )
        linestyles = _PLTSET_LINESTYLES.get(color_cycler)
        return (
            list(_PLTSET_PALETTES[color_cycler]),
            None if linestyles is None else list(linestyles),
        )
    if isinstance(color_cycler, (list, tuple)):
        return list(color_cycler), None
    raise TypeError(
        "color_cycler must belong to {}, or a list, or a tuple! Got an object of type {}.".format(
            sorted(_PLTSET_PALETTES), type(color_cycler).__name__
        )
    )


def pltset(color_cycler=None, **kargs):
    """Apply the default mdapy drawing template to ``plt.rcParams``.

    The template refers to the science style from
    `SciencePlots <https://github.com/garrettj403/SciencePlots>`_ and the colors
    also come from `Paul Tol's website <https://personal.sron.nl/~pault/>`_.

    The function calls ``plt.rcParams.clear()``, installs the selected color
    cycle as ``axes.prop_cycle``, sets the template values (inward ticks,
    visible minor ticks, thin axes, serif font with Computer Modern math text,
    frameless legend, font size 11) and finally applies the user overrides
    given in ``kargs``.

    Args:
        color_cycler (str | list[str] | tuple[str], optional): Name of a
            built-in palette, one of ['bright', 'high-contrast', 'high-vis',
            'light', 'muted', 'retro', 'std-colors', 'vibrant'], or a list or
            tuple of colors such as ['#FFAABB', '#99DDFF', '#44BB99'].
            Defaults to None, which selects the 'std-colors' palette. The
            'high-vis' palette cycles the line style together with the color.
        **kargs: Entries written to ``plt.rcParams`` after the template
            values. An entry whose assignment raises is printed and skipped.
            For example, to use the 'Times New Roman' font and hide the minor
            ticks:

            pltset(**{'font.serif':'Times New Roman', 'xtick.minor.visible':False, 'ytick.minor.visible':False})

            To change the axes linewidth and major tick width, and use a bold
            font:

            pltset(**{"xtick.major.width":1., "ytick.major.width":1., "axes.linewidth":1., "font.weight":'bold', "axes.labelweight":'bold'})

            More parameters can be found in plt.rcParams.keys().

    Raises:
        ValueError: If ``color_cycler`` is a string that is not a known
            palette name.
        TypeError: If ``color_cycler`` is neither None, a string, a list nor
            a tuple.
    """
    # Resolve and validate the palette first so that a bad argument raises
    # before any global matplotlib state has been touched.
    colors, linestyles = _pltset_palette(color_cycler)
    prop_cycle = cycler("color", colors)
    if linestyles is not None:
        prop_cycle = prop_cycle + cycler("ls", linestyles)

    plt.rcParams.clear()
    plt.rcParams["axes.prop_cycle"] = prop_cycle

    # Built on every call so that no mutable value is shared between calls.
    template = {
        "xtick.direction": "in",
        "xtick.major.size": 3,
        "xtick.major.width": 0.5,
        "xtick.minor.size": 1.5,
        "xtick.minor.width": 0.5,
        "xtick.minor.visible": True,
        "xtick.top": True,
        "ytick.direction": "in",
        "ytick.major.size": 3,
        "ytick.major.width": 0.5,
        "ytick.minor.size": 1.5,
        "ytick.minor.width": 0.5,
        "ytick.minor.visible": True,
        "ytick.right": True,
        "axes.linewidth": 0.5,
        "lines.linewidth": 1.5,
        "lines.markersize": 4,
        "legend.frameon": False,
        "font.serif": ["cmr10", "Computer Modern Serif", "DejaVu Serif"],
        "font.family": "serif",
        "axes.formatter.use_mathtext": True,
        "mathtext.fontset": "cm",
        "font.size": 11,
    }
    for key, value in template.items():
        plt.rcParams[key] = value

    # User overrides are best-effort: an entry that rcParams refuses is
    # reported on stdout and the remaining entries are still applied.
    for key, value in kargs.items():
        try:
            plt.rcParams[key] = value
        except Exception as e:
            print(e)
[docs]
def pltset_old(color_cycler=None, **kargs):
    """Reproduce the plotting style used by mdapy<0.9.9.

    The template from ``pltset`` is applied first, then the legacy values
    below (bold 'Times New Roman' text, thicker axes and major ticks, no
    minor ticks, no ticks on the top/right sides), and finally the caller's
    overrides. Note that the 'Times New Roman' font is required.

    Args:
        color_cycler (str | list[str] | tuple[str], optional): Forwarded to
            ``pltset``: a palette name belonging to ['bright',
            'high-contrast', 'high-vis', 'light', 'muted', 'retro',
            'std-colors', 'vibrant'], or a list or tuple of colors such as
            ['#FFAABB', '#99DDFF', '#44BB99']. Defaults to None.
        **kargs: Entries written to ``plt.rcParams`` after the legacy values.
            For example, to use a normal font weight and set the font size:

            pltset_old(**{"font.weight":'normal', "axes.labelweight":'normal', "font.size" : 10})

            More parameters can be found in plt.rcParams.keys().
    """
    pltset(color_cycler=color_cycler)

    legacy_style = {
        "xtick.major.width": 1.0,
        "ytick.major.width": 1.0,
        "axes.linewidth": 1.0,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        "font.serif": "Times New Roman",
        "xtick.minor.visible": False,
        "ytick.minor.visible": False,
        "xtick.top": False,
        "ytick.right": False,
        "font.size": 12,
        "legend.fontsize": 10,
    }

    # Legacy values go in first so that the caller's overrides win.
    # An entry whose assignment raises is printed and skipped.
    for settings in (legacy_style, kargs):
        for name, setting in settings.items():
            try:
                plt.rcParams[name] = setting
            except Exception as err:
                print(err)
[docs]
def cm2inch(value):
    """Convert a length from centimeters to inches.

    Args:
        value (float): length in centimeters.

    Returns:
        float: the same length in inches (``value / 2.54``).
    """
    centimeters_per_inch = 2.54
    return value / centimeters_per_inch
if __name__ == "__main__":
    # Small visual demo: apply the legacy style and draw a two-panel figure.
    #
    # Other ways to configure the style:
    #   pltset_old(color_cycler="high-vis", **{"font.weight":'normal', "axes.labelweight":'normal', })
    #   pltset_old(**{"font.weight":'normal', "axes.labelweight":'normal', })
    #   pltset('high-vis', **{'font.serif':'Times New Roman', 'xtick.minor.visible':False, 'ytick.minor.visible':False,})
    pltset_old()

    # NOTE(review): set_figure is not defined in the visible part of this
    # file; presumably it is provided elsewhere in the module - confirm.
    _fig, axes = set_figure(
        ncol=2,
        nrow=1,
        figsize=(16, 6),
        wspace=0.3,
        use_pltset=False,
        use_pltset_old=False,
    )

    # First panel: three shifted ramps, each with its own legend entry.
    for shift in range(3):
        axes[0].plot(range(shift, 10 + shift), label=shift)

    # Second panel: a single line drawn with circle markers.
    axes[1].plot(range(12), "o-")

    for panel in range(2):
        axes[panel].set_xlabel("x label")
        axes[panel].set_ylabel("y label")

    axes[0].legend()
    plt.show()