Source code for sliceplots.two_dimensional

# -*- coding: utf-8 -*-

"""Module containing useful 2D plotting abstractions on top of matplotlib."""

import matplotlib.transforms as transforms
import numpy as np
from matplotlib.artist import setp, getp
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

from sliceplots.util import _idx_from_val
from typing import Any, Optional


class Plot2D:
    r"""Pseudo-color plot of a 2D array with optional 1D slices attached.

    Parameters
    ----------
    fig : :class:`~matplotlib.figure.Figure`, optional
        Empty figure to draw on. If ``None``, a new
        :class:`~matplotlib.figure.Figure` will be created.
        Defaults to ``None``.
    arr2d: :py:class:`np.ndarray`
        Data to be plotted.
    h_axis: :py:class:`np.ndarray`
        Values on the "x" axis.
    v_axis: :py:class:`np.ndarray`
        Values on the "y" axis.
    xlabel: str, optional
        x-axis label.
    ylabel: str, optional
        y-axis label.
    zlabel: str, optional
        Label for :py:class:`~matplotlib.colorbar.Colorbar`.
    kwargs : dict, optional
        Other plot options, see examples below.

    Other Parameters
    ----------------
    extent : tuple, optional
        ``(xmin, xmax, ymin, ymax)``; the data and both axes are cropped to
        this window. Defaults to the full range of ``h_axis`` and ``v_axis``.
    vmin, vmax : float, optional
        Color limits, also used as the value limits of the slice panels.
        Default to the minimum and maximum of the cropped data.
    norm : optional
        Forwarded to :py:meth:`~matplotlib.axes.Axes.imshow`. When it is a
        normalization instance, ``vmin``/``vmax`` are applied to it after the
        image is created (this modifies the instance that was passed in).
    cmap : optional
        Colormap for the main panel. Defaults to ``"viridis"``.
    cbar : bool, optional
        Whether to draw a colorbar inside the main panel. Defaults to ``True``.
    hslice_val, vslice_val : float, optional
        Position of the horizontal / vertical slice. A slice panel is only
        drawn when the corresponding value is given.
    hslice_opts, vslice_opts : dict, optional
        Line properties for the slice lines, merged on top of the defaults
        ``{"ls": "-", "color": "#ff7f0e", "lw": 1.5}``.
    text : str, optional
        Text written in the lower-left corner of the main panel.

    Examples
    --------
    .. plot::
        :include-source:

        import numpy as np
        from matplotlib import pyplot
        from sliceplots import Plot2D

        uu = np.linspace(0, np.pi, 128)
        data = np.cos(uu - 0.5) * np.cos(uu.reshape(-1, 1) - 1.0)

        fig = pyplot.figure(figsize=(8,8))
        Plot2D(
            fig=fig,
            arr2d=data,
            h_axis=uu,
            v_axis=uu,
            xlabel=r"$x$ ($\mu$m)",
            ylabel=r"$y$ ($\mu$m)",
            zlabel=r"$\rho$ (cm${}^{-3}$)",
            hslice_val=0.75,
            vslice_val=2.75,
            hslice_opts={"color": "#1f77b4", "lw": 1.5, "ls": "-"},
            vslice_opts={"color": "#d62728", "ls": "-"},
            cmap="viridis",
            cbar=True,
            extent=(0, np.pi, 0, np.pi),
            vmin=-1.0,
            vmax=1.0,
            text="your text here",
        )
    """

    # Color used for the default slice lines, the corner text and the colorbar.
    _accent_color = "#ff7f0e"

    def __init__(
        self,
        *,
        fig: Optional[Figure] = None,
        arr2d: np.ndarray,
        h_axis: np.ndarray,
        v_axis: np.ndarray,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        zlabel: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        self.extent = kwargs.get(
            "extent", (np.min(h_axis), np.max(h_axis), np.min(v_axis), np.max(v_axis))
        )

        # Crop the data and both axes to the requested window.
        xmin, xmax, ymin, ymax = self.extent
        xmin_idx, xmax_idx = _idx_from_val(h_axis, xmin), _idx_from_val(h_axis, xmax)
        ymin_idx, ymax_idx = _idx_from_val(v_axis, ymin), _idx_from_val(v_axis, ymax)
        # NOTE(review): `_idx_from_val` is defined in `sliceplots.util` and is not
        # visible here. If it returns the index of the sample at the given value,
        # these half-open slices drop the sample at xmax / ymax -- TODO confirm
        # that this is intended.
        self.data = arr2d[ymin_idx:ymax_idx, xmin_idx:xmax_idx]
        self.h_axis = h_axis[xmin_idx:xmax_idx]
        self.v_axis = v_axis[ymin_idx:ymax_idx]

        self.min_data, self.max_data = np.amin(self.data), np.amax(self.data)
        self.vmin, self.vmax = (
            kwargs.get("vmin", self.min_data),
            kwargs.get("vmax", self.max_data),
        )

        self.label = {"x": xlabel, "y": ylabel, "z": zlabel}
        self.cbar = kwargs.get("cbar", True)
        # see https://matplotlib.org/users/colormapnorms.html
        self.norm = kwargs.get("norm")

        # A slice index stays None when the corresponding slice was not requested.
        self.hslice_val = kwargs.get("hslice_val")
        self.vslice_val = kwargs.get("vslice_val")
        self.hslice_idx = None
        self.vslice_idx = None
        if self.hslice_val is not None:
            self.hslice_idx = _idx_from_val(self.v_axis, self.hslice_val)
        if self.vslice_val is not None:
            self.vslice_idx = _idx_from_val(self.h_axis, self.vslice_val)

        self.text = kwargs.get("text", "")

        if fig is None:
            # make new figure
            self.fig = Figure()
            self.canvas = FigureCanvas(self.fig)
        else:
            self.fig = fig
            self.canvas = self.fig.canvas

        self.im = None  # image to be created by .imshow()
        self.ax0 = None  # main axes
        self.axh = None  # horizontal slice axes
        self.axv = None  # vertical slice axes

        self._draw_fig(**kwargs)

    def _main_panel(self, **kwargs):
        """Draw the pseudo-color image on the main axes and label its axes."""
        image_opts = dict(
            origin="lower",
            extent=self.extent,
            aspect="auto",
            interpolation="none",
            cmap=kwargs.get("cmap", "viridis"),
        )
        if self.norm is None or isinstance(self.norm, str):
            self.im = self.ax0.imshow(
                self.data, norm=self.norm, vmin=self.vmin, vmax=self.vmax, **image_opts
            )
        else:
            # Matplotlib refuses vmin/vmax next to a normalization instance, so
            # hand over the instance alone and set the color limits afterwards.
            self.im = self.ax0.imshow(self.data, norm=self.norm, **image_opts)
            self.im.set_clim(self.vmin, self.vmax)

        self.ax0.set_xlabel(self.label["x"])
        self.ax0.set_ylabel(self.label["y"])

    def _mark_hslice(self, opts):
        """Draw the horizontal slice line on the main axes and label its position."""
        position = self.v_axis[self.hslice_idx]
        self.ax0.axhline(y=position, **opts)
        # x follows the transform of the y tick labels, y is in data coordinates
        trans = transforms.blended_transform_factory(
            self.ax0.get_yticklabels()[0].get_transform(), self.ax0.transData
        )
        self.ax0.text(
            0,
            position,
            "{:.1f}".format(position),
            color=opts["color"],
            transform=trans,
            ha="right",
            va="center",
        )

    def _mark_vslice(self, opts):
        """Draw the vertical slice line on the main axes and label its position."""
        position = self.h_axis[self.vslice_idx]
        self.ax0.axvline(x=position, **opts)
        # x is in data coordinates, y follows the transform of the x tick labels
        trans = transforms.blended_transform_factory(
            self.ax0.transData, self.ax0.get_xticklabels()[0].get_transform()
        )
        self.ax0.text(
            position,
            0,
            "{:.1f}".format(position),
            color=opts["color"],
            transform=trans,
            ha="center",
            va="top",
        )

    def _hslice_profile(self, opts):
        """Plot the data along the horizontal slice in the panel above the image."""
        self.axh.set_xmargin(0)  # otherwise ax0 may have white margins
        self.axh.set_ylabel(self.label["z"])
        self.axh.plot(self.h_axis, self.data[self.hslice_idx, :], **opts)
        self.axh.set_ylim(self.vmin, self.vmax)
        # the x axis is shared with the main panel, so hide it here
        self.axh.xaxis.set_visible(False)
        for side in ("top", "bottom", "right"):
            self.axh.spines[side].set_visible(False)

    def _vslice_profile(self, opts):
        """Plot the data along the vertical slice in the panel right of the image."""
        self.axv.set_ymargin(0)
        self.axv.set_xlabel(self.label["z"])
        self.axv.plot(self.data[:, self.vslice_idx], self.v_axis, **opts)
        self.axv.set_xlim(self.vmin, self.vmax)
        # the y axis is shared with the main panel, so hide it here
        self.axv.yaxis.set_visible(False)
        for side in ("top", "left", "right"):
            self.axv.spines[side].set_visible(False)

    def _add_colorbar(self):
        """Place a horizontal colorbar at the top of the main axes."""
        accent = self._accent_color
        cax = inset_axes(self.ax0, width="70%", height="3%", loc=9)
        cbar = self.fig.colorbar(self.im, cax=cax, orientation="horizontal")
        cbar.set_label(self.label["z"], color=accent)
        cbar.ax.xaxis.set_ticks_position("top")
        cbar.ax.xaxis.set_label_position("top")
        cbar.ax.tick_params(color=accent, width=1.5, labelsize=8)
        cbxtick_obj = getp(cbar.ax.axes, "xticklabels")
        setp(cbxtick_obj, color=accent)

    def _draw_fig(self, **kwargs):
        """Lay out the axes and draw the image, the requested slices and extras.

        The grid has an extra row on top when a horizontal slice was requested
        and an extra column on the right when a vertical slice was requested.
        """
        slice_opts = {"ls": "-", "color": self._accent_color, "lw": 1.5}  # defaults
        hslice_opts = slice_opts.copy()
        vslice_opts = slice_opts.copy()
        hslice_opts.update(kwargs.get("hslice_opts", {}))
        vslice_opts.update(kwargs.get("vslice_opts", {}))

        has_hslice = self.hslice_idx is not None
        has_vslice = self.vslice_idx is not None

        nrows, height_ratios = (2, [1, 3]) if has_hslice else (1, [1])
        ncols, width_ratios = (2, [3, 1]) if has_vslice else (1, [1])
        gs = GridSpec(
            nrows, ncols, height_ratios=height_ratios, width_ratios=width_ratios
        )

        # The main panel sits below the horizontal slice panel, if there is one.
        main_row = nrows - 1
        self.ax0 = self.fig.add_subplot(gs[main_row, 0])
        if has_hslice:
            self.axh = self.fig.add_subplot(gs[0, 0], sharex=self.ax0)
        if has_vslice:
            self.axv = self.fig.add_subplot(gs[main_row, 1], sharey=self.ax0)

        self._main_panel(**kwargs)

        # Markers on the main panel first, then the attached profiles.
        if has_hslice:
            self._mark_hslice(hslice_opts)
        if has_vslice:
            self._mark_vslice(vslice_opts)
        if has_hslice:
            self._hslice_profile(hslice_opts)
        if has_vslice:
            self._vslice_profile(vslice_opts)

        # Keep the slice panels close to the main panel.
        spacing = {}
        if has_vslice:
            spacing["wspace"] = 0.03
        if has_hslice:
            spacing["hspace"] = 0.03
        if spacing:
            self.fig.subplots_adjust(**spacing)

        self.ax0.text(
            0.02, 0.02, self.text, transform=self.ax0.transAxes, color=self._accent_color
        )

        if self.cbar:
            self._add_colorbar()