import os
from itertools import product
from warnings import warn

import networkx as nx
import numpy as np
from bokeh.embed import file_html
from bokeh.io import curdoc
from bokeh.models import ColumnDataSource
from bokeh.models.widgets import DataTable, TableColumn
from bokeh.plotting import figure
from bokeh.resources import CDN
from bokeh.themes import Theme
from matplotlib import colors as cl
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from plotly import graph_objs as go
from plotly import tools
from plotly.offline import iplot, plot

from pyspectools import routines
"""
Commonly used formatting options
These are generally parts of Matplotlib that I commonly change,
but always have to look up stack overflow to find out how to do...
"""
def strip_spines(spines, axis):
    """
    Hide the named spines of a Matplotlib axis.

    Parameters
    ----------
    spines : iterable
        Keys into ``axis.spines`` identifying which spines to hide.
    axis : object
        Axis-like object exposing a ``spines`` mapping whose values
        provide ``set_visible``.
    """
    for side in spines:
        axis.spines[side].set_visible(False)
def no_scientific(axis):
    """
    Disable the offset on the major tick formatters of both axes.

    Calls ``set_useOffset(False)`` on the major formatter of the x axis
    and then of the y axis.

    Parameters
    ----------
    axis : object
        Axis-like object providing ``get_xaxis`` and ``get_yaxis``.
    """
    for fetch_axis in (axis.get_xaxis, axis.get_yaxis):
        fetch_axis().get_major_formatter().set_useOffset(False)
"""
Specific figure types
These include recipes/routines to generate commonly used
figures, provided the data is supplied in a digestible
way for the scripts.
These include:
- Polyad diagrams
- Generic energy diagrams
- Adding images to a matplotlib figure
"""
def make_pes(x, energies, width=5):
    """
    Build padded x/energy arrays so stationary points plot as plateaus.

    Every (x, energy) pair is expanded into ``width * 2`` points: the x
    values are linearly spaced over ``x +/- width * 0.05`` and the energy
    is repeated for each of them.

    Parameters
    ----------
    x : iterable of float
        Positions of the stationary points.
    energies : iterable
        Energy of each stationary point, paired with ``x`` via ``zip``.
    width : int, optional
        Half the number of points generated per stationary point.

    Returns
    -------
    new_x : numpy.ndarray
        Concatenated, padded x values.
    new_energies : numpy.ndarray
        Flattened array with each energy repeated ``width * 2`` times.
    """
    n_points = width * 2
    half_span = width * 0.05
    segments = []
    plateaus = []
    for centre, level in zip(x, energies):
        segments.append(np.linspace(centre - half_span, centre + half_span, n_points))
        plateaus.append([level] * n_points)
    # Seed with an empty float array so that empty input still yields an
    # empty array, as repeated np.append calls would.
    padded_x = np.concatenate([np.array([])] + segments)
    return padded_x, np.array(plateaus).flatten()
def calc_vibE(quant_nums, vibrations):
    """
    Calculate the vibrational energy of each quantum number configuration.

    The energy of a configuration is the sum of the element-wise product
    of its quantum numbers with the vibrational frequencies.

    Parameters
    ----------
    quant_nums : iterable of array-like
        One configuration of quantum numbers per entry; nested lists and
        2D NumPy arrays are both accepted.
    vibrations : array-like
        Vibrational frequencies, one per mode.

    Returns
    -------
    list of dict
        One dict per configuration with keys ``"state"`` (the
        configuration as a NumPy array) and ``"energy"`` (its energy).
    """
    # Coerce to arrays so that plain Python lists multiply element-wise
    # rather than raising a TypeError (list * list is not defined).
    vibrations = np.asarray(vibrations)
    energies = []
    for state in quant_nums:
        state = np.asarray(state)
        energies.append({"state": state, "energy": np.sum(state * vibrations)})
    return energies
def generate_x_coord(quant_nums, energies):
    """
    Bin states into x-axis categories by which modes are excited.

    Each configuration is reduced to a boolean pattern (quantum number
    greater than zero); every unique pattern found in ``quant_nums``
    becomes one category, and each item of ``energies`` is placed in the
    category matching the pattern of its ``"state"``.

    Parameters
    ----------
    quant_nums : array-like
        2D collection of quantum numbers, one configuration per row.
    energies : iterable of dict
        Items with a ``"state"`` entry, e.g. as returned by ``calc_vibE``.

    Returns
    -------
    dict
        Maps category index to the list of items in that category.

    Raises
    ------
    IndexError
        If an item's excitation pattern does not occur in ``quant_nums``.
    """
    # np.asarray lets nested lists be compared against zero element-wise
    unique_combos = np.unique(np.asarray(quant_nums) > 0, axis=0)
    cat_dict = {x: [] for x in np.arange(len(unique_combos))}
    for item in energies:
        pattern = np.asarray(item["state"]) > 0
        matches = np.flatnonzero((unique_combos == pattern).all(axis=1))
        if matches.size == 0:
            raise IndexError(
                f"State {item['state']} has an excitation pattern that is "
                "not present in quant_nums."
            )
        cat_dict[matches[0]].append(item)
    return cat_dict
def make_elevel_plot(cat_dict, axis, color, maxE, alpha=1.0):
    """
    Draw an energy level diagram onto a Matplotlib axis.

    This is the lower level routine; prefer wrappers such as
    ``vib_energy_diagram``. Each key of ``cat_dict`` is treated as an
    x-axis bin: all levels in the bin are drawn with a single ``hlines``
    call, and each level below ``maxE`` is annotated with its quantum
    numbers.

    Parameters
    ----------
    cat_dict : dict
        Maps the x-axis value of a bin to a list of dicts, each holding
        an ``"energy"`` and a ``"state"`` (quantum number configuration).
    axis : object
        Matplotlib-like axis providing ``hlines`` and ``text``.
    color : str
        Colour used for both the lines and the annotations.
    maxE : float
        Levels at or above this energy are drawn but not annotated.
    alpha : float, optional
        Opacity of the lines and annotations.
    """
    # Offset applied to every bin position along x
    spacing = 5
    # Horizontal extent of each level
    width = 1.0
    half_width = width / 2.0
    for index, levels in cat_dict.items():
        ys = [item["energy"] for item in levels]
        # Go through tolist() so NumPy scalars become plain Python numbers;
        # otherwise NumPy >= 2 renders labels such as "(np.int64(1) ...)".
        labels = [
            str(tuple(np.asarray(item["state"]).tolist())).replace(",", "")
            for item in levels
        ]
        centre = index + spacing
        axis.hlines(
            ys,
            [index - half_width + spacing for _ in levels],
            [index + half_width + spacing for _ in levels],
            color=color,
            alpha=alpha,
        )
        for y, text in zip(ys, labels):
            # Only annotate levels that fall below the energy cutoff
            if y < maxE:
                axis.text(
                    centre,
                    y + 30.0,
                    text,
                    horizontalalignment="center",
                    color=color,
                    alpha=alpha,
                    size=8.0,
                )
def add_image(axis, filepath, zoom=0.15, position=(0.0, 0.0)):
    """
    Add an image annotation to a Matplotlib axis.

    The file is read with ``plt.imread`` using the PNG format, wrapped in
    an ``OffsetImage`` and attached to the axis without a frame.

    Parameters
    ----------
    axis : matplotlib axis
        Axis that receives the image.
    filepath : str
        Path to the image file.
    zoom : float, optional
        Scaling factor passed to ``OffsetImage``.
    position : sequence of two floats, optional
        Location of the image in data coordinates. The default is an
        immutable tuple so it cannot be shared and mutated across calls.
    """
    image = OffsetImage(plt.imread(filepath, format="png"), zoom=zoom)
    image.image.axes = axis
    box = AnnotationBbox(
        image, position, xybox=position, xycoords="data", frameon=False
    )
    axis.add_artist(box)
def vib_energy_diagram(
    quant_nums, vibrations, maxV=2, maxE=3000.0, useFull=True, image=None, imagesize=0.1
):
    """
    Generate a vibrational energy level diagram.

    Wraps ``make_elevel_plot``: the supplied states are drawn in red and,
    optionally, every predicted configuration is drawn underneath in
    translucent black. Dashed guide lines mark each fundamental.

    Parameters
    ----------
    quant_nums : array-like
        2D collection of quantum numbers; each row is one configuration.
    vibrations : sequence of float
        Vibrational frequencies.
    maxV : int, optional
        Predicted configurations use quanta from ``np.arange(0, maxV)``.
        NOTE(review): ``maxV`` itself is excluded by ``np.arange``;
        confirm whether an inclusive upper bound was intended.
    maxE : float, optional
        Upper limit of the energy axis; levels above it are not annotated.
    useFull : bool, optional
        If True, the predicted configurations are plotted as well.
    image : str or None, optional
        Path to an image to place near the right edge of the plot.
    imagesize : float, optional
        Zoom factor for the image.

    Returns
    -------
    fig, ax
        The Matplotlib figure and axis.
    """
    quant_nums = np.asarray(quant_nums)
    # Every combination of quanta across all modes
    full_quant = np.arange(0, maxV)
    full_combo = np.array(list(product(full_quant, repeat=len(vibrations))))
    full_energies = calc_vibE(full_combo, vibrations)
    full_cat_dict = generate_x_coord(full_combo, full_energies)
    energies = calc_vibE(quant_nums, vibrations)
    # With predictions shown, bin the supplied states on the same x
    # categories as the predictions so the two sets line up.
    if useFull is True:
        cat_dict = generate_x_coord(full_combo, energies)
    else:
        cat_dict = generate_x_coord(quant_nums, energies)
    fig, ax = plt.subplots(figsize=(5, 5.5))
    if useFull is True:
        make_elevel_plot(full_cat_dict, ax, "black", maxE, 0.6)
    make_elevel_plot(cat_dict, ax, "#e41a1c", maxE)
    ax.set_xticklabels([])
    ax.set_xticks([])
    ax.set_xlabel("Vibrational state")
    ax.set_ylabel("Energy (cm$^{-1}$)")
    ax.set_ylim([-50.0, maxE])
    minx, maxx = ax.get_xlim()
    # Annotate the fundamentals
    last_vibration = None
    for vibration in vibrations:
        ax.hlines(
            vibration,
            minx,
            maxx,
            color="#377eb8",
            linestyle="--",
            zorder=0.0,
            alpha=0.3,
        )
        last_vibration = vibration
    # NOTE(review): the source's indentation was lost, so it is unclear
    # whether the image was added once or once per fundamental. It is
    # placed once here, level with the last fundamental.
    if image is not None and last_vibration is not None:
        add_image(
            ax, image, zoom=imagesize, position=[maxx - 0.5, last_vibration]
        )
    fig.tight_layout()
    return fig, ax
def overlay_dr_spectrum(
    dataframe, progressions, freq_col="Frequency", int_col="Intensity", **kwargs
):
    """
    Overlay frequency progressions on top of an observed spectrum.

    The spectrum is drawn as a translucent trace, and each progression is
    drawn as connected markers placed at 1.2 times the intensity of the
    nearest spectral point.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Spectrum holding the frequency and intensity columns.
    progressions : sequence of array-like
        Each entry is a collection of frequencies forming one progression.
    freq_col : str, optional
        Name of the frequency column.
    int_col : str, optional
        Name of the intensity column.
    kwargs
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    plotly.graph_objs.FigureWidget
    """
    layout = define_layout("Frequency (MHz)", "Intensity")
    fig = go.FigureWidget(layout=layout)
    fig.add_scattergl(
        x=dataframe[freq_col], y=dataframe[int_col], name="Observation", opacity=0.4
    )
    colors = generate_colors(len(progressions), cmap=plt.cm.tab10)
    max_freq = np.max(dataframe[freq_col])
    for index, (progression, color) in enumerate(zip(progressions, colors)):
        progression = np.asarray(progression)
        # Discard frequencies beyond the upper end of the spectrum
        progression = progression[progression <= max_freq]
        if progression.size == 0:
            # Nothing to draw; an empty lookup result cannot be indexed
            continue
        nearest = np.array(
            [routines.find_nearest(dataframe[freq_col], freq) for freq in progression]
        )
        # Column 1 is used as the positional index of the nearest point
        # (assumes find_nearest returns (value, index) - TODO confirm).
        # The stacked array may be float, so cast before handing to iloc.
        indices = nearest[:, 1].astype(int)
        y = dataframe[int_col].iloc[indices] * 1.2
        fig.add_scattergl(
            x=progression,
            y=y,
            marker={"color": color},
            mode="markers+lines",
            hoverinfo="name+x",
            name="Progression {}".format(index),
        )
    return fig
def dr_network_diagram(connections, **kwargs):
    """
    Draw a graph of connections using NetworkX and Plotly.

    Nodes are the unique values found in ``connections`` and edges are the
    supplied pairs. Nodes are positioned with the NetworkX shell layout,
    and one line trace is drawn per connected component.

    Parameters
    ----------
    connections : list of 2-tuples
        Pairs of connected frequencies.
    kwargs
        If ``cmap`` is among them, all kwargs are forwarded to
        ``generate_colors``; otherwise the ``tab10`` colormap is used.

    Returns
    -------
    fig : plotly.graph_objs.FigureWidget
        The rendered diagram.
    connected : list of set
        Connected components of the graph.
    """
    graph = nx.Graph()
    unique_freqs = np.unique(connections)
    for frequency in unique_freqs:
        graph.add_node(frequency)
    for pair in connections:
        graph.add_edge(*pair)
    # Shell layout arranges the nodes on a circle
    positions = nx.shell_layout(graph)
    color_kwarg = {"cmap": plt.cm.tab10}
    if "cmap" in kwargs:
        color_kwarg.update(**kwargs)
    node_coords = np.array(list(positions.values()))
    connected = list(nx.connected_components(graph))
    colors = generate_colors(len(connected), **color_kwarg)
    hidden_axis = {
        "showgrid": False,
        "zeroline": False,
        "ticks": "",
        "showticklabels": False,
    }
    fig_layout = {
        "height": 700.0,
        "width": 700.0,
        "autosize": True,
        "xaxis": dict(hidden_axis),
        "yaxis": dict(hidden_axis),
        "showlegend": False,
    }
    fig = go.FigureWidget(layout=fig_layout)
    # Nodes
    fig.add_scattergl(
        x=node_coords[:, 0],
        y=node_coords[:, 1],
        text=list(unique_freqs),
        hoverinfo="text",
        mode="markers",
    )
    # One line per connected component, through its nodes in sorted order
    for component, color in zip(connected, colors):
        path = np.array([positions[node] for node in sorted(component)])
        fig.add_scattergl(
            x=path[:, 0],
            y=path[:, 1],
            mode="lines",
            hoverinfo=None,
            name="",
            opacity=0.4,
            marker={"color": color},
        )
    return fig, connected
def init_plotly_subplot(nrows, ncols, **kwargs):
    """
    Create a Plotly FigureWidget laid out as a grid of subplots.

    Parameters
    ----------
    nrows : int
        Number of subplot rows.
    ncols : int
        Number of subplot columns.
    kwargs
        Forwarded to ``tools.make_subplots``.

    Returns
    -------
    plotly.graph_objs.FigureWidget
    """
    return go.FigureWidget(tools.make_subplots(rows=nrows, cols=ncols, **kwargs))
def stacked_plot(
    dataframe, frequencies, freq_range=0.002, freq_col="Frequency", int_col="Intensity"
):
    """
    Create a Loomis-Wood style stacked plot around a set of frequencies.

    One subplot is produced per centre frequency that lies within the
    spectrum, showing the spectrum as a function of offset from that
    centre. The plotted window is ``freq_range`` times the centre
    frequency on either side. The input DataFrame is not modified.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Broadband spectrum.
    frequencies : array-like of float
        Centre frequencies.
    freq_range : float, optional
        Fraction of each centre frequency used as the half-width.
    freq_col : str, optional
        Name of the frequency column.
    int_col : str, optional
        Name of the intensity column.

    Returns
    -------
    plotly.graph_objs.FigureWidget
        Figure holding the subplots.

    Raises
    ------
    ValueError
        If none of the frequencies fall within the spectrum.
    """
    frequencies = np.asarray(frequencies)
    freq_axis = dataframe[freq_col]
    # Keep only centres inside the band covered by the spectrum
    in_band = (frequencies >= freq_axis.min()) & (frequencies <= freq_axis.max())
    # Descending order puts the highest frequency in the top subplot
    frequencies = np.sort(frequencies[in_band])[::-1]
    nplots = len(frequencies)
    if nplots == 0:
        raise ValueError(
            "None of the requested frequencies fall within the spectrum."
        )
    titles = tuple("{:.0f} MHz".format(frequency) for frequency in frequencies)
    fig = init_plotly_subplot(
        nrows=nplots,
        ncols=1,
        **{"subplot_titles": titles, "vertical_spacing": 0.15, "shared_xaxes": True},
    )
    for index, frequency in enumerate(frequencies):
        # Work on a local Series so no "Offset N" columns are written
        # back into the caller's DataFrame.
        offset = freq_axis - frequency
        freq_cutoff = freq_range * frequency
        window = (offset > -freq_cutoff) & (offset < freq_cutoff)
        trace = go.Scattergl(
            x=offset[window], y=dataframe.loc[window, int_col], mode="lines"
        )
        # Plotly subplot rows are numbered from one
        fig.add_trace(trace, index + 1, 1)
        fig["layout"]["xaxis1"].update(
            range=[-freq_cutoff, freq_cutoff],
            title="Offset frequency (MHz)",
            showgrid=True,
        )
        fig["layout"]["yaxis" + str(index + 1)].update(showgrid=False)
    fig["layout"].update(autosize=True, height=1000, width=900, showlegend=False)
    return fig
def plot_catchirp(chirpdf, catfiles=None):
    """
    Build an interactive view of a chirp spectrum with reference catalogs.

    The experimental spectrum is drawn as a line trace. Each reference
    catalog is drawn as a bar trace on a secondary y axis, with its
    intensities divided by the catalog's minimum intensity.

    Parameters
    ----------
    chirpdf : pandas.DataFrame
        Spectrum with "Frequency" and "Intensity" columns.
    catfiles : dict or None, optional
        Maps species names to paths of their .cat files.

    Returns
    -------
    plotly.graph_objs.FigureWidget
    """
    plots = [
        go.Scattergl(
            x=chirpdf["Frequency"], y=chirpdf["Intensity"], name="Experiment"
        )
    ]
    if catfiles is not None:
        color_palette = generate_colors(len(catfiles))
        for color, species in zip(color_palette, catfiles):
            # NOTE(review): `pc` is not imported in the visible part of
            # this file - confirm which module provides pick_pickett.
            species_df = pc.pick_pickett(catfiles[species])
            bar = go.Bar(
                x=species_df["Frequency"],
                y=species_df["Intensity"] / species_df["Intensity"].min(),
                name=species,
                marker={"color": color},
                width=1.0,
                opacity=0.6,
                yaxis="y2",
            )
            plots.append(bar)
    background = "#f0f0f0"
    layout = go.Layout(
        autosize=False,
        height=600,
        width=900,
        xaxis={"title": "Frequency (MHz)"},
        paper_bgcolor=background,
        plot_bgcolor=background,
        yaxis={"title": ""},
        yaxis2={"title": "", "side": "right", "overlaying": "y", "range": [0.0, 1.0]},
    )
    return go.FigureWidget(data=plots, layout=layout)
def plot_df(dataframe, cols=None, **kwargs):
    """
    Plot columns of a DataFrame against its "Frequency" column.

    Wraps the lower level ``plot_column`` and displays the figure
    with ``iplot``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Every column is treated as an intensity except "Frequency".
    cols : iterable or None, optional
        Columns to plot; if None, every column except "Frequency".
    kwargs
        ``xlabel`` and ``ylabel`` are passed to ``define_layout``; any
        other keyword is passed to ``generate_colors``.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    if cols is None:
        cols = [key for key in dataframe.keys() if key != "Frequency"]
    # The two consumers have disjoint signatures, so route each keyword to
    # the function that accepts it rather than passing all to both.
    layout_keys = ("xlabel", "ylabel")
    layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys}
    color_kwargs = {k: v for k, v in kwargs.items() if k not in layout_keys}
    # Three fixed colours cover the "fewer than four columns" case; with
    # only two, zip() would silently drop the third column.
    default_palette = ["#66c2a5", "#fc8d62", "#8da0cb"]
    if len(cols) < 4:
        colors = default_palette
    else:
        colors = generate_colors(len(cols), **color_kwargs)
    traces = [
        plot_column(dataframe, col, color=color) for col, color in zip(cols, colors)
    ]
    layout = define_layout(**layout_kwargs)
    figure = go.Figure(data=traces, layout=layout)
    iplot(figure)
    return figure
def plot_assignment(spec_df, assignments_df, col="Intensity"):
    """
    Plot an experimental spectrum together with its assignments.

    The spectrum is drawn first, followed by one bar trace per unique
    "Chemical Name" in the assignments, each in its own colour. The
    figure is rendered with ``plot`` and also returned.

    Parameters
    ----------
    spec_df : pandas.DataFrame
        Experimental data.
    assignments_df : pandas.DataFrame
        Assignments; must contain a "Chemical Name" column.
    col : str, optional
        Column of ``spec_df`` to plot.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    molecules = assignments_df["Chemical Name"].unique()
    # One colour per molecule plus one for the experimental trace
    palette = color_iterator(len(molecules) + 1)
    traces = [plot_column(spec_df, col, color=next(palette))]
    traces.extend(
        plot_bar_assignments(
            assignments_df.loc[assignments_df["Chemical Name"] == molecule],
            next(palette),
        )
        for molecule in molecules
    )
    layout = define_layout()
    layout["yaxis"] = {"title": "Experimental Intensity"}
    # Catalog intensities go on a second, logarithmic y axis
    layout["yaxis2"] = {
        "title": "CDMS/JPL Intensity",
        "overlaying": "y",
        "side": "right",
        "type": "log",
        "autorange": True,
    }
    figure = go.Figure(data=traces, layout=layout)
    plot(figure)
    return figure
def generate_colors(n, cmap=plt.cm.Spectral, hex=True):
    """
    Generate a linearly spaced color series from a Matplotlib colormap.

    Parameters
    ----------
    n : int
        Number of colours to generate.
    cmap : str or Matplotlib colormap, optional
        Colormap to sample. A str is looked up among the Matplotlib
        colormaps; if the name is unknown a warning is issued and the
        Spectral colormap is used instead.
    hex : bool, optional
        If True, hex colour codes are returned; otherwise the array
        produced by the colormap.

    Returns
    -------
    colors
        List of hex codes, or the colormap output array.
    """
    if isinstance(cmap, str):
        try:
            cmap = plt.get_cmap(cmap)
        except ValueError:
            warn(f"{cmap} not found in Matplotlib, defaulting to Spectral.")
            # Actually fall back; otherwise cmap stays a str and the call
            # below fails with "'str' object is not callable".
            cmap = plt.cm.Spectral
    colormap = cmap(np.linspace(0.0, 1.0, n))
    if hex is True:
        return [cl.rgb2hex(color) for color in colormap]
    return colormap
def color_iterator(n, **kwargs):
    """
    Generator that yields ``n`` colours one at a time.

    Intended for cases where several plot types draw from one palette.

    Parameters
    ----------
    n : int
        Number of colours to produce.
    kwargs
        Forwarded to ``generate_colors``.

    Yields
    ------
    A single colour from the generated series.
    """
    colors = generate_colors(n, **kwargs)
    for position in range(n):
        yield colors[position]
def plot_bar_assignments(species_df, color="#fc8d62"):
    """
    Create a bar trace for one chemical species on the second y axis.

    Parameters
    ----------
    species_df : pandas.DataFrame
        Slice of an assignments dataframe for a single species. The
        columns "Chemical Name", "Combined", "CDMS/JPL Intensity" and
        "Resolved QNs" are read.
    color : str, optional
        Colour of the bars.

    Returns
    -------
    plotly.graph_objs.Bar
    """
    # The first unique name labels the trace
    legend_name = species_df["Chemical Name"].unique()[0]
    bar_options = {
        "x": species_df["Combined"],
        # The stored intensity is used as a base-10 exponent
        "y": 10 ** species_df["CDMS/JPL Intensity"],
        "name": legend_name,
        "text": species_df["Resolved QNs"],
        "marker": {"color": color},
        "width": 0.25,
        "yaxis": "y2",
        "opacity": 0.9,
    }
    return go.Bar(**bar_options)
def plot_column(dataframe, col, name=None, color=None, layout=None):
    """
    Low level helper that plots one DataFrame column against "Frequency".

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain a "Frequency" column.
    col : str
        Column to plot on the y axis.
    name : str or None, optional
        Legend name; defaults to ``col``.
    color : str or None, optional
        Marker colour; defaults to "#1c9099".
    layout : optional
        If given (and truthy), a figure is built and shown with ``iplot``.

    Returns
    -------
    plotly.graph_objs.Scatter or None
        The trace when no layout is supplied; None when the figure is
        displayed directly.
    """
    legend_name = col if name is None else name
    shade = "#1c9099" if color is None else color
    trace = go.Scatter(
        x=dataframe["Frequency"],
        y=dataframe[col],
        name=legend_name,
        marker={"color": shade},
    )
    if not layout:
        return trace
    figure = go.Figure(data=[trace], layout=layout)
    iplot(figure)
    return None
def define_layout(xlabel="", ylabel=""):
    """
    Create the standard Plotly layout used by the plotting helpers.

    Parameters
    ----------
    xlabel : str, optional
        Title of the x axis.
    ylabel : str, optional
        Title of the y axis.

    Returns
    -------
    plotly.graph_objs.Layout
        850 x 650 layout with a white background, Roboto font and an
        empty annotations list.
    """
    white = "#ffffff"
    settings = {
        "xaxis": {"title": xlabel, "tickformat": ".,"},
        "yaxis": {"title": ylabel},
        "autosize": True,
        "height": 650.0,
        "width": 850.0,
        "paper_bgcolor": white,
        "plot_bgcolor": white,
        "font": {"family": "Roboto", "size": 14, "color": "#000000"},
        "annotations": [],
    }
    return go.Layout(**settings)
def save_plot(fig, filename, js=True):
    """
    Export a Plotly figure to a file, keeping interactivity.

    By default the plotly.js code is embedded, which results in
    relatively large files.
    NOTE(review): the original note recommends `save_html` instead; that
    function is not visible in this part of the file.

    Parameters
    ----------
    fig : plotly figure
        Figure to export.
    filename : str
        Destination path.
    js : optional
        Passed through as ``include_plotlyjs``.
    """
    export_options = {
        "filename": filename,
        "show_link": False,
        "auto_open": False,
        "include_plotlyjs": js,
    }
    plot(fig, **export_options)
def cfa_cmap(nsteps=100):
    """
    Generate a Matplotlib colormap with the CfA branding colors.

    Linearly interpolates from the CfA red to the so-called CfA violet.

    Parameters
    ----------
    nsteps : int, optional
        Number of steps in the interpolation, i.e. the number of
        distinct colours in the colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    """
    rgb_255 = [(141, 0, 52), (43, 53, 117)]
    # Matplotlib expects channels in [0, 1], so rescale the 8-bit values
    colors = [tuple(channel / 255.0 for channel in rgb) for rgb in rgb_255]
    # A plain colour list must go through from_list; the constructor
    # expects a segment-data dict and fails when the map is first used.
    return cl.LinearSegmentedColormap.from_list("cfa", colors, N=nsteps)
def pandas_bokeh_table(dataframe, html=False, **kwargs):
    """
    Convert a Pandas DataFrame to a Bokeh DataTable object.

    One column is generated per DataFrame key, titled with the
    capitalized key.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        DataFrame to convert.
    html : bool, optional
        If True, return a string of HTML (built with ``file_html`` and
        the CDN resources) for embedding instead of the table object.
    kwargs
        Passed to the ``DataTable`` constructor.

    Returns
    -------
    DataTable object if ``html`` is not True, otherwise str
    """
    source = ColumnDataSource(dataframe)
    # str() allows non-string column labels, which have no .capitalize()
    columns = [
        TableColumn(field=str(key), title=str(key).capitalize())
        for key in dataframe
    ]
    table = DataTable(source=source, columns=columns, **kwargs)
    if html is True:
        return file_html(table, CDN)
    return table