# Copyright (c) 2022-2026, Yongchao Wu in Aalto University
# This file is part of the mdapy project, released under the BSD 3-Clause License.
"""
Plotting Utilities for Scientific Publications
===============================================
This module provides utilities for creating publication-quality figures with Matplotlib.
It includes functions for configuring plot styles, creating figures with consistent
formatting, and saving figures with uniform margins.
The module is designed to produce figures suitable for scientific papers and presentations,
with sensible defaults for font sizes, line widths, tick marks, and color schemes.
Note: This module requires matplotlib to be installed. Install it with:
pip install matplotlib
Functions
---------
set_figure : Create a figure with scientific style settings
save_figure : Save a figure with uniform whitespace margins
_pltset : Configure global Matplotlib style (internal)
_cm2inch : Convert centimeters to inches (internal)
_ensure_matplotlib : Check matplotlib availability (internal)
Examples
--------
Basic usage for creating a simple plot:
>>> import numpy as np
>>> fig, ax = set_figure(figsize=(8.5, 7.0))
>>> x = np.linspace(0, 2 * np.pi, 100)
>>> ax.plot(x, np.sin(x), label="sin(x)")
>>> ax.set_xlabel("x")
>>> ax.set_ylabel("y")
>>> ax.legend()
>>> save_figure(fig, "output.png")
Creating a multi-panel figure:
>>> fig, axes = set_figure(figsize=(17, 7), nrow=1, ncol=2)
>>> for i, ax in enumerate(axes):
... ax.plot(x, np.sin(x * (i + 1)))
... ax.set_xlabel("x")
... ax.set_ylabel("y")
>>> save_figure(fig, "multi_panel.pdf")
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union, List, Tuple, Any, Literal
import numpy as np
if TYPE_CHECKING:
from matplotlib.figure import Figure
from matplotlib.axes import Axes
def _ensure_matplotlib() -> None:
    """
    Ensure matplotlib and its dependencies are available.

    This function imports matplotlib and cycler lazily, only when plotting
    functions are actually called. This allows the parent package to be
    installed without matplotlib as a hard dependency.

    Raises
    ------
    ImportError
        If matplotlib or cycler cannot be imported. The message includes
        installation instructions, and the original exception is chained
        as ``__cause__``.

    Notes
    -----
    The imports performed here are bound only in this function's local scope
    and are discarded when it returns. They do not leak into the caller's or
    the module's namespace. Callers must still import what they use (e.g.
    ``import matplotlib.pyplot as plt``) after calling this function. The call
    only guarantees that such imports will succeed.
    """
    try:
        # Probe every module the plotting helpers rely on.
        import matplotlib.pyplot  # noqa: F401
        from matplotlib.figure import Figure  # noqa: F401
        from matplotlib.axes import Axes  # noqa: F401
        from matplotlib.transforms import Bbox  # noqa: F401
        from cycler import cycler  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Matplotlib is required for plotting functionality but is not installed.\n"
            "Please install it using one of the following methods:\n"
            " pip install matplotlib\n"
            " conda install matplotlib\n"
            f"Original error: {e}"
        ) from e
def _pltset(
    color_cycler: Optional[Union[List[str], Tuple[str, ...]]] = None, **kwargs: Any
) -> None:
    """
    Configure global Matplotlib style optimized for scientific publications.

    The style is applied on top of Matplotlib's built-in defaults, so the
    resulting configuration does not depend on rcParams changed earlier in
    the session.

    Parameters
    ----------
    color_cycler : list of str or tuple of str, optional
        Custom color palette for plot lines. When omitted, a seven-color
        palette (blue, red, green, yellow, cyan, purple, gray) is used.
    **kwargs : Any
        Additional rcParams overrides applied after the built-in style, e.g.
        ``_pltset(**{"font.size": 12.0})``. Keys that are not valid rcParams
        are skipped with a ``UserWarning``.

    Raises
    ------
    ImportError
        If matplotlib or cycler is not installed.
    ValueError
        If ``color_cycler`` is empty, or if Matplotlib's rcParams validation
        rejects an override value.
    """
    _ensure_matplotlib()
    # Import here after ensuring matplotlib is available
    import warnings

    import matplotlib
    import matplotlib.pyplot as plt
    from cycler import cycler

    # Reset to Matplotlib's internal defaults via the documented API.
    # Style-blacklisted keys such as the backend are left untouched.
    matplotlib.rcdefaults()

    if color_cycler is None:
        color_cycler = [
            "#4477AA",  # Blue
            "#EE6677",  # Red
            "#228833",  # Green
            "#CCBB44",  # Yellow
            "#66CCEE",  # Cyan
            "#AA3377",  # Purple
            "#BBBBBB",  # Gray
        ]
    elif len(color_cycler) == 0:
        # An empty cycle is accepted here but breaks later at plot time with
        # an obscure error; fail early instead.
        raise ValueError("color_cycler must contain at least one color.")
    plt.rcParams["axes.prop_cycle"] = cycler("color", color_cycler)

    plt.rcParams.update(
        {
            # X-axis tick configuration
            "xtick.direction": "in",
            "xtick.major.size": 3,
            "xtick.major.width": 0.6,
            "xtick.minor.size": 1.5,
            "xtick.minor.width": 0.6,
            "xtick.top": True,
            "xtick.minor.visible": False,
            # Y-axis tick configuration
            "ytick.direction": "in",
            "ytick.major.size": 3,
            "ytick.major.width": 0.6,
            "ytick.minor.size": 1.5,
            "ytick.minor.width": 0.6,
            "ytick.right": True,
            "ytick.minor.visible": False,
            # Line and axes styling
            "axes.linewidth": 0.6,
            "lines.linewidth": 1.2,
            "lines.markersize": 3,
            # Font configuration
            "font.weight": "normal",
            "font.size": 10.0,
            "axes.labelweight": "normal",
            "legend.frameon": False,
            "legend.fontsize": 9.0,
            "axes.titlesize": 9.0,
            "font.family": "serif",
            "font.serif": ["Times New Roman", "Arial", "cmr10"],
            # Mathematical text
            "axes.formatter.use_mathtext": True,
            "mathtext.fontset": "cm",
        }
    )

    # Apply custom overrides. Because rcParams was reset to the full default
    # set above, the membership test recognizes every valid rcParam key.
    for key, value in kwargs.items():
        if key in plt.rcParams:
            plt.rcParams[key] = value
        else:
            warnings.warn(
                f"'{key}' is not a valid rcParam key and will be ignored.",
                UserWarning,
                stacklevel=2,
            )
def _cm2inch(value: Union[float, int]) -> float:
    """
    Translate a length given in centimeters into inches.

    Matplotlib figure sizes are specified in inches. This helper lets
    callers express them in centimeters instead.

    Parameters
    ----------
    value : float or int
        Length in centimeters.

    Returns
    -------
    float
        The same length expressed in inches (1 inch = 2.54 cm).
    """
    centimeters_per_inch = 2.54
    inches = value / centimeters_per_inch
    return inches
if __name__ == "__main__":
    # Demo: a 2 x 2 grid of scaled sine curves, displayed interactively.
    # NOTE(review): relies on set_figure, which is not defined in the visible
    # part of this module; confirm it exists before running as a script.
    style_overrides = {
        "font.size": 10.0,
        "lines.linewidth": 1.4,
    }
    fig, axes = set_figure(figsize=(17, 14), ncol=2, nrow=2, **style_overrides)

    x = np.linspace(0, 7, 100)
    panel_positions = ((0, 0), (0, 1), (1, 0), (1, 1))
    for idx, (row, col) in enumerate(panel_positions):
        scale = idx + 1
        ax = axes[row][col]
        ax.plot(x, np.sin(x) * scale, label=f"{scale}·sin(x)")
        ax.set_xlabel("X (arb)")
        ax.set_ylabel("Y (arb)")
        ax.legend()

    # To write the figure to disk instead of showing it:
    # save_figure(fig, "test_uniform.png", transparent=False)
    import matplotlib.pyplot as plt

    plt.show()