Source code for finchnmr.analysis

"""
Tools to analyze models.

Authors: Nathan A. Mahynski
"""
import copy
import matplotlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px

from . import substance
from . import model  # Only needed for type checking

from typing import Union, Any, ClassVar
from numpy.typing import NDArray


def plot_stacked_importances(
    optimized_models: list[Any],
    figsize: Union[tuple[int, int], None] = None,
    backend: str = "mpl",
    **imshow_kwargs: Any,
):
    """
    Draw the importances of several fitted models side by side as a heatmap.

    The importance vector of every model is stacked into a 2D array which is
    transposed before plotting, so each column of the image belongs to one
    model and each row to one entry of the importance vectors.

    Parameters
    ----------
    optimized_models : list
        Fitted models (see `optimize_models`). Each must provide an
        `importances()` method; for the plotly backend the substance names
        are also read from `_nmr_library._substances` of each model.

    figsize : tuple(int, int), optional(default=None)
        Figure size; only used by the matplotlib backend.

    backend : str, optional(default='mpl')
        Plotting library to use; 'mpl' uses matplotlib and is not
        interactive, while 'plotly' yields an interactive figure.

    imshow_kwargs : dict, optional(default=None)
        Extra keyword arguments forwarded to the backend's imshow() function;
        e.g., "cmap" (matplotlib) or "color_continuous_scale" (plotly).

    Returns
    -------
    if backend == 'mpl':
        image : matplotlib.image.AxesImage
            Heatmap of the importances; x-axis is labelled "Model Index" and
            y-axis "Library Substance Index".

        colorbar : matplotlib.pyplot.colorbar
            Colorbar belonging to the image.

    if backend == 'plotly':
        image : plotly.graph_objs._figure.Figure
            Heatmap of the importances with the same axis labels; hovering
            shows the substance index, model index and substance name.

    Raises
    ------
    ValueError
        If `backend` is neither 'mpl' nor 'plotly'.

    Example
    -------
    >>> optimized_models, analyses = finchnmr.model.optimize_models(...)
    >>> plot_stacked_importances(optimized_models, backend='mpl', cmap='RdBu')
    >>> plot_stacked_importances(optimized_models, backend='plotly', color_continuous_scale='RdBu')
    """
    # Rows of the stacked array are models; transpose so models run along x.
    grid = np.stack([fitted.importances() for fitted in optimized_models]).T

    if backend == "mpl":
        _, axis = plt.subplots(figsize=figsize)
        heatmap = axis.imshow(grid, **imshow_kwargs)
        axis.set_xlabel("Model Index")
        axis.set_ylabel("Library Substance Index")
        cbar = plt.colorbar(heatmap, ax=axis)
        cbar.set_label("Importance")
        return heatmap, cbar

    if backend == "plotly":
        figure = px.imshow(
            grid, text_auto=False, aspect="auto", **imshow_kwargs
        )
        figure.update_layout(
            xaxis_title="Model Index",
            yaxis_title="Library Substance Index",
            coloraxis_colorbar=dict(title="Importance"),
        )

        # Custom hover text, following
        # https://stackoverflow.com/questions/73649907/plotly-express-imshow-hover-text/73658192#73658192
        substance_names = np.array(
            [
                [entry.name for entry in fitted._nmr_library._substances]
                for fitted in optimized_models
            ]
        ).T
        # Row / column index of every cell in the image.
        row_index, col_index = np.indices(grid.shape)
        hover_data = np.dstack((row_index, col_index, substance_names))
        figure.update(
            data=[
                {
                    "customdata": hover_data,
                    "hovertemplate": "Substance Index: %{customdata[0]}<br>Model Index: %{customdata[1]}<br>Substance: %{customdata[2]}",
                }
            ]
        )
        return figure

    raise ValueError(f"Unrecognized backend {backend}")
class Analysis:
    """Set of analysis methods for analyzing fitted models."""

    _model: ClassVar["model._Model"]

    def __init__(self, model: "model._Model") -> None:
        """
        Instantiate the class.

        Parameters
        ----------
        model : _Model
            Fitted model to some target, see `optimize_models`.
        """
        setattr(self, "_model", model)

    def get_top_substances(
        self, k: int = 5, index: bool = False
    ) -> tuple[Union[list["substance.Substance"], list[int]], list[float]]:
        """
        Retrieve the most important substances to the model.

        Parameters
        ----------
        k : int, optional(default=5)
            Number of most important spectra to retrieve.  If -1 then get
            them all.

        index : bool, optional(default=False)
            If True, then return list of substance indices in the model
            library not the substance itself.

        Returns
        -------
        top_substances : list(Substance) or list(int)
            The most important substances, sorted from highest to lowest by
            the absolute value of their importance.  If `index=True` then
            this is a list of integers corresponding to the index of the
            substance in the model's library.

        top_importances : list(float)
            Importance of each substance (signed, as reported by the model),
            sorted from highest to lowest by absolute value.
        """
        # Pair every importance with its position, then rank by magnitude.
        # `sorted` is stable, so ties keep their original library order.
        ranked = sorted(
            enumerate(self._model.importances()),
            key=lambda pair: np.abs(pair[1]),
            reverse=True,
        )
        if k != -1:
            ranked = ranked[:k]

        top_importances = [importance_ for _, importance_ in ranked]
        top_substances: Union[list["substance.Substance"], list[int]]
        if index:
            top_substances = [idx_ for idx_, _ in ranked]
        else:
            library = self._model._nmr_library
            top_substances = [
                library.substance_by_index(idx_) for idx_, _ in ranked
            ]

        return top_substances, top_importances

    def plot_top_spectra(
        self,
        k: int = 5,
        plot_width: int = 3,
        figsize: Union[tuple[int, int], None] = (10, 5),
    ) -> NDArray:
        """
        Plot the HSQC NMR spectra that are the most important to the model using matplotlib.

        To visualize these results using another plotting backend, such as
        plotly, use `.get_top_substances` and create subplots as desired.

        Parameters
        ----------
        k : int, optional(default=5)
            Number of most important spectra to plot.  If -1 then plot them
            all.

        plot_width : int, optional(default=3)
            Number of subplots the grid will have along its width.

        figsize : tuple(int, int), optional(default=(10,5))
            Size of final figure.

        Returns
        -------
        axes : ndarray(matplotlib.pyplot.Axes, ndim=1)
            Flattened array of axes on which the top HSQC NMR spectra are
            plotted.  Axes beyond the number of plotted spectra have been
            removed from the figure.
        """
        top_substances, top_importances = self.get_top_substances(
            k=k, index=False
        )

        # Size the grid from what was actually returned: the model may hold
        # fewer substances than the `k` that was requested.
        n_plots = len(top_substances)
        plot_depth = max(1, int(np.ceil(n_plots / plot_width)))

        # squeeze=False always yields a 2D array of axes, so flattening also
        # works for a 1x1 grid (where subplots would return a bare Axes).
        _, axes_ = plt.subplots(
            nrows=plot_depth,
            ncols=plot_width,
            figsize=figsize,
            squeeze=False,
        )
        axes = axes_.flatten()

        # Plot the NMR spectra
        for ax_, s_, importance_ in zip(axes, top_substances, top_importances):
            s_.plot(ax=ax_)  # type: ignore[union-attr]
            ax_.set_title(
                s_.name + "\nI = {}".format("%.4f" % importance_)  # type: ignore[union-attr]
            )

        # Trim off extra subplots
        for ax_ in axes[n_plots:]:
            ax_.remove()

        return axes

    def plot_top_importances(
        self,
        k: int = 5,
        by_name: bool = False,
        figsize: Union[tuple[int, int], None] = None,
        backend: str = "mpl",
    ):
        """
        Plot the importances of the top substances in the model.

        Parameters
        ----------
        k : int, optional(default=5)
            Number of top importances to plot.  If -1 then plot them all.

        by_name : bool, optional(default=False)
            HSQC NMR spectra will be given by integer index in the library by
            default; if True, then use the associated substance name instead.

        figsize : tuple(int, int), optional(default=None)
            Size of final figure; only used by the matplotlib backend.

        backend : str, optional(default='mpl')
            Plotting library to use; the default 'mpl' uses matplotlib and is
            not interactive, while 'plotly' will yield interactive plots.

        Returns
        -------
        if backend == 'mpl':
            axes : matplotlib.pyplot.Axes
                Horizontal bar chart the importances are plotted on in
                descending order.

        if backend == 'plotly':
            figure : plotly.graph_objs._figure.Figure
                Horizontal bar chart the importances are plotted on in
                descending order.

        Raises
        ------
        ValueError
            If `backend` is neither 'mpl' nor 'plotly'.
        """
        top_substances, top_importances = self.get_top_substances(k, index=True)

        if by_name:
            library = self._model._nmr_library
            labels = [
                library.substance_by_index(idx_).name  # type: ignore[arg-type]
                for idx_ in top_substances
            ]
        else:
            labels = [str(idx_) for idx_ in top_substances]

        if backend == "mpl":
            _, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize)
            axes.barh(
                # Use the number of bars actually available rather than the
                # requested k, so positions, widths and labels always agree.
                # Reversed so the largest importance is drawn at the top.
                y=np.arange(len(top_importances))[::-1],
                width=list(top_importances),
                align="center",
                tick_label=labels,
            )
            axes.set_xlabel("Importance")
            axes.set_xscale("log")

            return axes
        elif backend == "plotly":
            df = pd.DataFrame(
                {"Label": labels, "Importance": list(top_importances)},
                columns=["Label", "Importance"],
            )
            fig = px.bar(
                data_frame=df,
                x="Importance",
                y="Label",
                orientation="h",
                log_x=True,
            )
            # Largest importance at the top, matching the matplotlib chart.
            fig.update_yaxes(autorange="reversed")
            fig.update_layout(barcornerradius=7)
            fig.update_layout(yaxis_title="")

            return fig
        else:
            raise ValueError(f"Unrecognized backend {backend}")

    def build_residual(self) -> "substance.Substance":
        """
        Create a substance whose spectrum is comprised of the residual (true spectrum - model).

        Returns
        -------
        residual : substance.Substance
            Artificial substance whose spectrum is the residual.  It is a deep
            copy of the model's target with its data replaced, so the target
            itself is not modified.
        """
        target = self._model.target()
        reconstructed = self._model.reconstruct()

        # Start from a copy of the target so everything except the data
        # carries over to the residual.
        residual = copy.deepcopy(target)
        residual._set_data(target.data - reconstructed.data)

        return residual

    def plot_residual(self, **kwargs):
        """
        Plot the residual (target - reconstructed) spectrum.

        An artificial substance is created representing the residual (see
        `build_residual`).  This is what is plotted, so it may be manipulated
        accordingly. Refer to the kwargs in `substance.Substance.plot`.

        The title "Target - Reconstructed" is only added when `backend` is
        passed explicitly in `kwargs`.

        Parameters
        ----------
        kwargs : dict, optional(default=None)
            Keyword arguments for `substance.Substance.plot`.

        Returns
        -------
        By default, or if kwargs['backend'] == 'mpl':
            image : matplotlib.image.AxesImage
                HSQC NMR residual spectrum as an image.

            colorbar : matplotlib.colorbar.Colorbar
                Colorbar to go with the image.

        if kwargs['backend'] == 'plotly':
            image : plotly.graph_objs._figure.Figure
                HSQC NMR residual spectrum as an image.

        Raises
        ------
        ValueError
            If `backend` is given in `kwargs` but is neither 'mpl' nor
            'plotly' (and `substance.Substance.plot` did not already reject
            it).

        Example
        -------
        >>> a = Analysis(...)
        >>> a.plot_residual(absolute_values=True, backend='mpl')
        >>> a.plot_residual(absolute_values=True, backend='plotly', cmap='viridis')
        """
        residual = self.build_residual()
        result = residual.plot(**kwargs)

        if "backend" in kwargs:
            # Bind the value locally; the error message below needs it.
            backend = kwargs["backend"]
            if backend == "mpl":
                plt.gca().set_title("Target - Reconstructed")
            elif backend == "plotly":
                result.update_layout(title="Target - Reconstructed")
            else:
                raise ValueError(f"Unrecognized backend {backend}")

        return result