import numpy as np
from linescanning import utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
import seaborn as sns
import string
from typing import Union
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
class Defaults():
    """Defaults

    Default settings for plotting. Every option is stored as an instance attribute; any keyword argument passed to
    the constructor is written into the instance dictionary and therefore overrides the default of the same name.

    Parameters
    ----------
    pad_title: int
        Set the distance between the title and the plot. Default = 20
    title_size: int
        Set the font size of titles. Default = 24; you can set it equal to font_size (or any other size attribute) by
        specifying the attribute name as a string, e.g., "font_size".
    font_size: int
        Set the font size of axis labels/titles. Default = 18
    label_size: int
        Set the font size of tick labels. Default = 14
    tick_width: float
        Set the line-width of the ticks. Default = 0.5
    tick_length: float
        Set the length of the ticks. Default = 0 (no ticks)
    axis_width: float
        Set the line-width of axes. Default = 0.5
    line_width: int
        Line widths for either all graphs (then *int*) or a *list* with the number of elements as requested graphs,
        default = 1.
    line_style: str
        Set the style of data in line-plots. Default = "solid"
    sns_offset: int
        Set the distance between y-axis and start of plot. Default = None
    sns_trim: bool
        Trim the axes following seaborn's convention. Default = False
    sns_bottom: bool
        Trim specifically the x-axis of plots. Default = False
    sns_despine: bool
        Whether :func:`_despine` calls ``seaborn.despine`` at all. Default = True
    sns_ori: str, optional
        Default orientation for bar-plots. Default is up-right (vertical). Allowed options are "v" or "h".
        :class:`linescanning.plotting.LazyBar()`
    sns_rot: int, float, optional
        Rotation of labels in bar plot. Specific to :class:`linescanning.plotting.LazyBar()`
    xkcd: bool
        Plot with cartoon style. Default = False
    fontname: str, optional
        Font family written to matplotlib's rcParams. Default = "Montserrat" ("Humor Sans" if `xkcd=True`)
    ylim_bottom: float
        Set the y-limiter at the bottom of the plot. Default = None
    ylim_top: float
        Set the y-limiter at the top of the plot. Default = None
    xlim_left: float
        Set the x-limiter at the left of the plot. Default = None
    xlim_right: float
        Set the x-limiter at the right of the plot. Default = None
    set_xlim_zero: bool
        Set the distance between plot and y-axis to 0. Default = False
    legend_handletext: float
        Set the distance between the handle and text in legends. Default = 0.25
    legend_kwargs: dict, optional
        Extra keyword arguments passed to ``ax.legend``. Default = {}
    x_label: str, optional
        Label of x-axis, by default None
    y_label: str, optional
        Label of y-axis, by default None
    labels: str, list, optional
        String (if 1 timeseries) or list (with the length of `ts`) of labels for the timeseries to be used in the
        legend, by default None
    title: str, dict, optional
        String or dictionary collecting the following keys representing information about the title:

        >>> title = {
        >>>     'title': "some title",   # title text
        >>>     'color': 'k',            # color (default = 'k')
        >>>     'fontweight': "bold"     # fontweight (default = 'normal'), any matplotlib fontweight option
        >>> }
    color: str, list, optional
        String (if 1 timeseries) or list (with the length of `ts`) of colors, by default None. If nothing is
        specified, we'll use `cmap` to create a color palette
    axs: matplotlib axis, optional
        Existing axis to draw on, by default None. If this is not a matplotlib axis, :func:`_set_figure_axs` creates
        a new figure
    save_as: str, list, optional
        Save the plot, by default None. If you want to use figures in Inkscape, save them as PDFs to retain high
        resolution; specify a list of strings to save the plot with multiple extensions
    y_lim: list, optional
        List for `self._set_ylim`
    x_lim: list, optional
        List for `self._set_xlim`
    x_ticks, y_ticks, z_ticks: list, optional
        Tick locations, by default None
    x_ticklabels, y_ticklabels, z_ticklabels: list, optional
        Tick labels, by default None
    x_dec: int, optional
        Enforce `x_ticks` to have `x_dec` decimal accuracy. Default is whatever the data dictates
    y_dec: int, optional
        Enforce `y_ticks` to have `y_dec` decimal accuracy. Default is whatever the data dictates
    add_hline: dict, optional
        Dictionary for a horizontal line through the plot, by default None. Collects the following items (a number
        or list of numbers is interpreted as `pos`):

        >>> add_hline = {
        >>>     'pos': 0,          # position (required; number or list of numbers)
        >>>     'color': 'k',      # color
        >>>     'lw': 0.5,         # linewidth
        >>>     'ls': 'dashed',    # linestyle
        >>>     'min': 0,          # start of the line in axes-fraction
        >>>     'max': 1           # end of the line in axes-fraction
        >>> }

        You can get the settings above by specifying *add_hline='default'*. Other strings are ignored by
        :func:`_add_line`.
    add_vline: dict, optional
        Dictionary for a vertical line through the plot, by default None. Same keys as `add_hline`
    dpi: int, optional
        Save figures with DPI-value. Default is 300
    figure_background_color: str, optional
        Background of images. Default is "white"
    bbox_inches: str, optional
        Bounding box settings. Default is "tight"
    """

    def __init__(self, **kwargs):

        # names of all formatting options known to this class
        self.ls_kwargs = [
            "pad_title",
            "font_size",
            "title_size",
            "label_size",
            "tick_width",
            "tick_length",
            "axis_width",
            "line_width",
            "line_style",
            "sns_offset",
            "sns_trim",
            "sns_bottom",
            "sns_ori",
            "sns_rot",
            "sns_despine",
            "xkcd",
            "ylim_bottom",
            "ylim_top",
            "xlim_left",
            "xlim_right",
            "set_xlim_zero",
            "legend_handletext",
            "x_label",
            "y_label",
            "title",
            "save_as",
            "y_lim",
            "x_lim",
            "x_ticks",
            "y_ticks",
            "z_ticks",
            "x_ticklabels",
            "y_ticklabels",
            "z_ticklabels",
            "axs",
            "color",
            "y_dec",
            "x_dec",
            "add_vline",
            "add_hline",
            "dpi",
            "figure_background_color",
            "bbox_inches",
            "fontname",
            "legend_kwargs"
        ]

        self.pad_title = 20
        self.font_size = 18
        self.title_size = 24
        self.label_size = 14
        self.tick_width = 0.5
        self.tick_length = 0
        self.axis_width = 0.5
        self.line_width = 1
        self.line_style = "solid"
        self.sns_offset = None
        self.sns_trim = False
        self.sns_bottom = False
        self.sns_ori = "v"
        self.sns_rot = None
        self.xkcd = False
        self.ylim_bottom = None
        self.ylim_top = None
        self.xlim_left = None
        self.xlim_right = None
        self.set_xlim_zero = False
        self.legend_handletext = 0.25
        self.title = None
        self.save_as = None
        self.axs = None
        self.color = None
        self.add_vline = None
        self.add_hline = None
        self.dpi = 300
        self.figure_background_color = "white"
        self.bbox_inches = "tight"
        self.fontname = None
        self.sns_despine = True
        self.legend_kwargs = {}

        # {x,y,z}_{label,lim,ticks,ticklabels,dec} all default to None
        for i in ["label", "lim", "ticks", "ticklabels", "dec"]:
            for ii in ["x", "y", "z"]:
                setattr(self, f"{ii}_{i}", None)

        # update kwargs
        self.__dict__.update(kwargs)

        # set default font
        if self.xkcd:
            self.fontname = "Humor Sans"
        else:
            if not isinstance(self.fontname, str):
                self.fontname = "Montserrat"

        # update font widely
        self.update_rc(self.fontname)

    def update_rc(self, font):
        """update font family in matplotlib's global rcParams"""
        plt.rcParams.update({'font.family': font})

    def _set_figure_axs(self, figsize=None):
        """create a new figure/axis unless `self.axs` already is a matplotlib axis"""
        if not isinstance(self.axs, mpl.axes.Axes):
            if not isinstance(figsize, tuple):
                # subclasses define `figsize`; fall back to matplotlib's default if it is missing
                figsize = getattr(self, "figsize", None)

            _, self.axs = plt.subplots(figsize=figsize)

    def _set_spine_width(self, ax):
        """update spine width"""
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_linewidth(self.axis_width)

    def _set_axlabel(self, ax, lbl, axis="x", **kwargs):
        """set the label of the x-, y-, or z-axis; ignored if `ax` lacks the setter or `lbl` is not a str/list"""
        setter = getattr(ax, f"set_{axis}label", None)
        if setter is not None and isinstance(lbl, (str, list)):
            setter(
                lbl,
                fontsize=self.font_size,
                fontname=self.fontname,
                **kwargs
            )

    def _set_tick_params(self, ax, **kwargs):
        """set width/length/labelsize of ticks"""
        ax.tick_params(
            width=self.tick_width,
            length=self.tick_length,
            labelsize=self.label_size,
            **kwargs)

    def _set_title(self, ax, title, **kwargs):
        """set title of plot

        `title` is either the title text or a dictionary with the keys 'title' (required), 'color', and
        'fontweight'. Any other type is ignored.
        """

        # `title_size` may name another attribute (e.g., "font_size"); resolve it to that attribute's value
        if isinstance(self.title_size, str):
            if hasattr(self, self.title_size):
                self.title_size = getattr(self, self.title_size)

        if not isinstance(title, (str, dict)):
            return

        if isinstance(title, str):
            # use the text that was passed in rather than whatever is stored on the instance
            title_dict = {"title": title}
        else:
            title_dict = title.copy()

        # add default keys if they're missing in dictionary
        title_dict.setdefault('color', 'k')
        title_dict.setdefault('fontweight', 'normal')

        ax.set_title(
            title_dict["title"],
            color=title_dict["color"],
            fontweight=title_dict["fontweight"],
            fontname=self.fontname,
            fontsize=self.title_size,
            pad=self.pad_title,
            **kwargs)

    def _set_bar_lim(self, ax, lim):
        """set the limits of the value-axis of a bar plot (x for horizontal bars, y for vertical bars)"""
        if isinstance(lim, list):
            if self.sns_ori == 'h':
                fc = self._set_xlim
            elif self.sns_ori == "v":
                fc = self._set_ylim
            else:
                raise ValueError(f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

            # set
            fc(ax, lim)

    def _set_bar_ticks(self, ax, ticks):
        """set the ticks of the value-axis of a bar plot (x for horizontal bars, y for vertical bars)"""
        if isinstance(ticks, list):
            if self.sns_ori == 'h':
                fc = "x"
            elif self.sns_ori == "v":
                fc = "y"
            else:
                raise ValueError(f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

            # set
            self._set_ticks(ax, ticks, axis=fc)

    @staticmethod
    def _set_ticks(ax, ticks, axis="x"):
        """set tick locations; accepts lists, arrays, pandas objects, or the output of get_?ticklabels()"""
        setter = getattr(ax, f"set_{axis}ticks", None)
        if setter is None:
            return

        if isinstance(ticks, (pd.Series, pd.DataFrame)):
            ticks = ticks.values

        if isinstance(ticks, np.ndarray):
            ticks = [float(i) for i in ticks.ravel()]

        if isinstance(ticks, list):
            # check if elements are output of get_?ticklabels()
            if all(isinstance(i, mpl.text.Text) for i in ticks):
                # matplotlib may render negative numbers with a unicode minus, which float() cannot parse
                ticks = [float(i.get_text().replace("\u2212", "-")) for i in ticks]

            setter(ticks)

    @staticmethod
    def _set_ticklabels(ax, ticks, axis="x", **kwargs):
        """set tick labels; accepts lists, arrays, or pandas objects"""
        setter = getattr(ax, f"set_{axis}ticklabels", None)
        if setter is None:
            return

        if isinstance(ticks, (pd.Series, pd.DataFrame)):
            ticks = ticks.values

        if isinstance(ticks, np.ndarray):
            flat = ticks.ravel()
            try:
                ticks = [float(i) for i in flat]
            except (TypeError, ValueError):
                # labels are allowed to be text; keep them as they are
                ticks = flat.tolist()

        if isinstance(ticks, list):
            setter(ticks, **kwargs)

    @staticmethod
    def _set_ylim(ax, lim):
        """set y-limit; a list sets both limits, a single number only sets the top"""
        if isinstance(lim, list):
            ax.set_ylim(lim)
        elif isinstance(lim, (int, float)):
            ax.set_ylim(top=lim)

    @staticmethod
    def _set_xlim(ax, lim):
        """set x-limit; a list sets both limits, a single number only sets the left"""
        if isinstance(lim, list):
            ax.set_xlim(lim)
        elif isinstance(lim, (int, float)):
            ax.set_xlim(left=lim)

    def _despine(self, ax, **kwargs):
        """despine plot (only if `sns_despine` is set)"""
        if self.sns_despine:
            sns.despine(
                ax=ax,
                offset=self.sns_offset,
                trim=self.sns_trim,
                **kwargs
            )

    @staticmethod
    def _set_ticker(ax, dec, axis="x"):
        """format the major ticks of an axis with `dec` decimals; ignored if `dec` is not an integer"""
        target = getattr(ax, f"{axis}axis", None)
        if target is not None and isinstance(dec, int):
            from matplotlib.ticker import FormatStrFormatter
            target.set_major_formatter(FormatStrFormatter(f"%.{dec}f"))

    def _set_shaded_error(
        self,
        x: np.ndarray=None,
        tc: np.ndarray=None,
        ax: "mpl.axes.Axes"=None,
        yerr: np.ndarray=None,
        **kwargs):
        """draw a shaded error band around a timecourse

        Parameters
        ----------
        x: np.ndarray
            x-axis of the timecourse
        tc: np.ndarray
            the timecourse itself
        ax: matplotlib axis
            axis to draw on
        yerr: int, float, list, np.ndarray
            either a scalar or an array with the length of `tc` (symmetric error around `tc`), or two elements
            representing the lower and upper bound. Any other type is ignored

        Raises
        ----------
        TypeError
            if an error array with the length of `tc` contains only NaNs
        ValueError
            if `yerr` neither matches the length of `tc` nor has two elements
        """

        if not isinstance(yerr, (int, float, list, np.ndarray)):
            return

        if np.isscalar(yerr) or len(yerr) == len(tc):
            if not np.isscalar(yerr):
                if all(np.isnan(i) for i in yerr):
                    raise TypeError(f"Error contains only NaNs")

            # np.asarray so that list-inputs support arithmetic
            ymin = tc - np.asarray(yerr)
            ymax = tc + np.asarray(yerr)
        elif len(yerr) == 2:
            ymin, ymax = yerr
        else:
            raise ValueError(
                f"Error has {len(yerr)} elements; expected a scalar, {len(tc)} elements (length of data), or 2 elements (lower/upper bound)"
            )

        ax.fill_between(
            x,
            ymax,
            ymin,
            **kwargs)

    def _draw_errorbar(
        self,
        x: np.ndarray=None,
        y: np.ndarray=None,
        ax: "mpl.axes.Axes"=None,
        yerr: np.ndarray=None,
        xerr: np.ndarray=None,
        **kwargs):
        """draw error bars with ``ax.errorbar``

        Defaults that are filled in if not specified: line width (`self.line_width`), black error bars, `zorder=0`,
        and no connecting line between the points. pandas inputs are converted to arrays.
        """

        # set bunch of defaults
        if "linewidth" not in kwargs and "lw" not in kwargs:
            kwargs["lw"] = self.line_width

        # do not clobber an explicitly requested error bar color
        if "color" not in kwargs and "ecolor" not in kwargs:
            kwargs["ecolor"] = "k"

        if "zorder" not in kwargs:
            kwargs["zorder"] = 0

        if "linestyle" not in kwargs and "ls" not in kwargs:
            kwargs["ls"] = "None"

        if isinstance(x, (pd.Series, pd.DataFrame)):
            x = x.values

        if isinstance(y, (pd.Series, pd.DataFrame)):
            y = y.values

        if isinstance(yerr, (pd.Series, pd.DataFrame)):
            yerr = yerr.values

        if isinstance(xerr, (pd.Series, pd.DataFrame)):
            xerr = xerr.values

        ax.errorbar(
            x,
            y,
            yerr=yerr,
            xerr=xerr,
            **kwargs)

    def _set_legend_labels(self, ax, labels=None):
        """add a legend if `labels` is a list/array; settings in `legend_kwargs` take precedence over defaults"""
        if isinstance(labels, (list, np.ndarray)):

            # work on a copy so the user's dictionary is left alone
            user_kw = dict(self.legend_kwargs)
            if "font_size" in user_kw:
                # "font_size" is this module's spelling; ax.legend expects "fontsize"
                user_kw.setdefault("fontsize", user_kw.pop("font_size"))

            legend_kw = {
                "frameon": False,
                "handletextpad": self.legend_handletext,
                "fontsize": self.label_size
            }
            legend_kw.update(user_kw)

            ax.legend(**legend_kw)

    def _save_as(self, save_as, **kwargs):
        """simple save function; saves the current pyplot figure if `save_as` is a string"""
        if isinstance(save_as, str):
            plt.savefig(
                save_as,
                bbox_inches=self.bbox_inches,
                dpi=self.dpi,
                facecolor=self.figure_background_color,
                **kwargs
            )

    def _save_figure(self, save_as):
        """save same figure with multiple extensions"""
        if isinstance(save_as, (list, str)):
            if isinstance(save_as, str):
                save_as = [save_as]

            for ii in save_as:
                self._save_as(ii)

    def _return_element(self, ddict, el, ix=0):
        """return `ddict[el]`, or its `ix`-th element if it is a list"""
        if isinstance(ddict[el], list):
            return ddict[el][ix]
        else:
            return ddict[el]

    def _add_line(
        self,
        ax=None,
        **kwargs
        ):
        """draw the horizontal/vertical lines requested through `add_hline`/`add_vline`

        Each attribute can be a number, a list/array of numbers, the string "default" (line at 0), or a dictionary
        with at least the key 'pos'. Dictionary values for 'color', 'lw', 'ls', 'min', and 'max' can be lists, in
        which case the i-th element is used for the i-th position. Additional kwargs are passed to
        ``ax.axhline``/``ax.axvline``.

        Raises
        ----------
        ValueError
            if a dictionary is specified without the 'pos' key
        """

        default_dict = {
            'color': 'k',
            'ls': 'dashed',
            'lw': 0.5,
            'min': 0,
            'max': 1
        }

        for ii in ["hline", "vline"]:

            spec = getattr(self, f"add_{ii}")
            if isinstance(spec, str):
                # "default" is the only string we know how to interpret
                if spec != "default":
                    continue
                spec = {'pos': 0}
            elif isinstance(spec, (float, int, list, np.ndarray)):
                spec = {"pos": spec}
            elif isinstance(spec, dict):
                # copy so defaults are not written into the caller's dictionary
                spec = dict(spec)
            else:
                continue

            for key, val in default_dict.items():
                spec.setdefault(key, val)

            if "pos" not in spec:
                raise ValueError(f"Need the 'pos' key to denote position..")

            # enforce list so we only need to call functions once
            positions = spec['pos']
            if isinstance(positions, (int, float)):
                positions = [positions]

            if not isinstance(positions, (list, np.ndarray)):
                continue

            # define function
            if ii == "hline":
                ffunc = ax.axhline
                kwargs_list = ["color", "lw", "ls", "xmin", "xmax"]
            else:
                ffunc = ax.axvline
                kwargs_list = ["color", "lw", "ls", "ymin", "ymax"]

            # loop through elements
            for ix, line in enumerate(positions):
                line_kw = {
                    val: self._return_element(spec, key, ix=ix)
                    for key, val in zip(["color", "lw", "ls", "min", "max"], kwargs_list)
                }

                # run func
                ffunc(line, **line_kw, **kwargs)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
class LazyPRF(Defaults):
    """LazyPRF

    Plot the geometric location of the Gaussian pRF.

    Parameters
    ----------
    prf: numpy.ndarray
        instantiation of `gauss2D_iso_cart`; will be np.squeeze'ed over the first axis if `ndim >= 3`.
    vf_extent: list
        the space the pRF lives in
    cmap: str, optional
        Colormap for imshow; accepts output from :func:`linescanning.utils.make_binary_cm`. Defaults to 'RdBu_r'
    cross_color: str, optional
        Color for the fixation cross; defaults to 'white'. You can set it to 'k' if you have a binary colormap as
        input
    alpha: float, optional
        Opacity for imshow. Defaults to 1
    shrink_factor: float, optional
        When the background of the image is white, we create a black border around the Circle patch. If this is
        equal to `vf_extent`, the border is cut off at some points. This factor shrinks the radius of the Circle, so
        that we can have a nice border. When set to 0.9, it becomes sort of like a target. This is relevant for
        **all** color maps other than `magma`/`viridis`, specifically a
        :func:`linescanning.utils.make_binary_cm` object. Default = 1
    full_axis: bool, optional
        If `True`, the entire axis of `vf_extent` will be used for the ticks (more detailed). If `False`, only the
        outer bounds of `vf_extent` are used as ticks (looks cleaner). Explicitly specified `x_ticks`/`y_ticks`
        take precedence. Default = False
    axis_off: bool, optional
        If `True`, the axis will be turned off, and `full_axis` and other tick/axis parameters are ignored. If
        `False`, the x/y axis will be maintained, and the `vf_extent` will be given as ticks. Default = True
    figsize: tuple, optional
        Figure dimensions if a new figure needs to be created. Default = (8,8)
    vf_only: bool, optional
        Only show the outline of the the visual field, without pRF. You still need to specify the pRF as we'll
        `imshow` an empty array with the same shape rather than the pRF. Default = False
    line_width: float, optional
        Width of the outer border of the visual field. Default is 1 (inherited from :class:`Defaults`)
    cross_width: float, optional
        Width of the cross denoting the x/y axis. Default is 0.5, but can be increased if `cmap` is not *viridis*
        or *magma* to enhance visibility
    concentric: optional
        Stored on the instance; not used by the plotting code of this class
    z_lines: int, optional
        Set the order of the vertical/horizontal lines. Default is **on top** of the pRF (1)
    z_prf: int, optional
        Set the order of the pRF imshow. Default is below the axis lines, but can be changed to be on top of them.
        Default = 0
    edge_color: str, optional
        Edge color of the Circle patch outlining the visual field. Default = None
    imshow_kw: dict, optional
        Additional kwargs passed on to `imshow`

    Returns
    ----------
    matplotlib.pyplot plot
    """

    def __init__(
        self,
        prf,
        vf_extent,
        cmap='RdBu_r',
        cross_color="white",
        alpha=None,
        shrink_factor=1,
        axis_off=True,
        figsize=(8,8),
        full_axis=False,
        vf_only=False,
        cross_width=0.5,
        concentric=None,
        z_lines=1,
        z_prf=0,
        edge_color=None,
        imshow_kw=None,
        **kwargs):

        self.prf = prf
        self.vf_extent = vf_extent
        self.cmap = cmap
        self.cross_color = cross_color
        self.alpha = alpha
        self.shrink_factor = shrink_factor
        self.axis_off = axis_off
        self.figsize = figsize
        self.full_axis = full_axis
        self.vf_only = vf_only
        self.cross_width = cross_width
        self.concentric = concentric
        self.z_lines = z_lines
        self.z_prf = z_prf
        self.edge_color = edge_color

        # avoid sharing one dictionary between instances
        self.imshow_kw = {} if imshow_kw is None else dict(imshow_kw)

        # let the parent process the kwargs so that font selection (e.g., for xkcd) sees the user's settings
        super().__init__(**kwargs)

        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # save
        self._save_figure(self.save_as)

    def plot(self):
        """draw the pRF (or an empty visual field) clipped to a circle, including the fixation cross"""

        # set figure axis
        self._set_figure_axs()

        if self.prf.ndim >= 3:
            self.prf = np.squeeze(self.prf, axis=0)

        if self.alpha is None:
            self.alpha = 1

        # add cross-hair; note that this replaces any user-specified add_hline/add_vline
        for ii in ["hline", "vline"]:
            self.line_kw = {
                "pos": 0,
                "color": self.cross_color,
                "lw": self.cross_width
            }
            setattr(self, f"add_{ii}", self.line_kw)

        self._add_line(
            self.axs,
            zorder=self.z_lines
        )

        if not self.vf_only:
            plot_obj = self.prf
        else:
            plot_obj = np.zeros_like(self.prf)

        # color limits are symmetric around 0.
        # NOTE(review): if the pRF has negatives, the limits are +/- the minimum, so positive values larger than
        # |minimum| saturate the colormap -- confirm this is intended
        if plot_obj.min() < 0:
            vmin = plot_obj.min()
            vmax = -plot_obj.min()
        else:
            vmin = -plot_obj.max()
            vmax = plot_obj.max()

        # list() so that concatenation also works for tuples/arrays
        extent = list(self.vf_extent)
        if len(extent) < 4:
            self.use_extent = extent + extent
        else:
            self.use_extent = extent

        im = self.axs.imshow(
            plot_obj,
            extent=self.use_extent,
            cmap=self.cmap,
            alpha=self.alpha,
            zorder=self.z_prf,
            vmin=vmin,
            vmax=vmax,
            **self.imshow_kw)

        # In case of a white background, the circle for the visual field is cut off, so we need to make an adjustment:
        if self.cmap != 'magma' and self.cmap != 'viridis':
            radius = self.use_extent[-1]*self.shrink_factor
        else:
            radius = self.use_extent[-1]

        # set title
        self._set_title(self.axs, self.title)

        # set patch
        self.patch = patches.Circle(
            (0,0),
            radius=radius,
            transform=self.axs.transData,
            edgecolor=self.edge_color,
            facecolor="None",
            linewidth=self.line_width)

        self.axs.add_patch(self.patch)
        im.set_clip_path(self.patch)

        if self.axis_off:
            self.axs.axis('off')
        else:
            # set tick params
            self._set_tick_params(self.axs)

            # set spine widths
            self._set_spine_width(self.axs)

            if self.full_axis:
                self.use_ticks = np.arange(self.vf_extent[0], self.vf_extent[1]+1, 1)
            else:
                self.use_ticks = self.vf_extent

            # ticks derived from the visual field, used for axes without user-specified ticks
            default_ticks = {}
            if len(self.use_extent) >= 4:
                bounds = {"x": self.use_extent[0:2], "y": self.use_extent[2:4]}
                for key, (lo, hi) in bounds.items():
                    if self.full_axis:
                        default_ticks[key] = np.arange(lo, hi+1, 1)
                    else:
                        default_ticks[key] = [lo, hi]

            # set ticks
            loop_funcs = [
                "_set_ticks",
                "_set_ticklabels",
                "_set_ticker"
            ]

            for x in ["x", "y"]:
                for ff, el in zip(
                    loop_funcs,
                    ["ticks", "ticklabels", "dec"]):

                    add_to_ax = getattr(self, f"{x}_{el}")
                    if el == "ticks" and add_to_ax is None:
                        add_to_ax = default_ticks.get(x)

                    getattr(self, ff)(self.axs, add_to_ax, axis=x)

            self._despine(self.axs)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
class LazyPlot(Defaults):
    """LazyPlot

    Class for plotting because I'm lazy and I don't want to go through the ``matplotlib`` motions every time I
    quickly want to visualize something. This class makes that a lot easier. It allows single inputs, lists with
    multiple timecourses, labels, error shadings, and much more.

    Parameters
    ----------
    ts: list, numpy.ndarray
        Input data. Can either be a single array/list of numbers, or a list of multiple numpy arrays. If you want
        labels, custom colors, or error bars, these inputs must come in lists of similar length as ``ts``!
    xx: list, numpy.ndarray, optional
        X-axis array. A list with one array-like element per timecourse is interpreted as timecourse-specific
        x-axes
    error: list, numpy.ndarray, optional
        Error data with the same length/shape as the input timeseries, by default None. Can be either a
        numpy.ndarray for 1 timeseries, or a list of numpy.ndarrays for multiple timeseries. Errors consisting of
        only NaNs are ignored
    error_alpha: float, optional
        Opacity level for error shadings, by default 0.3
    cmap: str, optional
        Color palette to use for colors if no individual colors are specified, by default 'viridis'
    figsize: tuple, optional
        Figure dimensions as per usual matplotlib conventions, by default (14,4)
    markers: str, list, optional
        Use markers during plotting. A single option will be applied to all elements in ``ts``. A separate list for
        each element of ``ts`` is also accepted. If one array in ``ts`` should not have markers, use `None`. E.g.,
        if `len(ts) == 3`, and we want only the first timecourse to have markers use: ``markers=['.',None,None]``
    markersize: str, list, optional
        Specify marker sizes during plotting. A single option will be applied to all elements in ``ts``. A separate
        list for each element of ``ts`` is also accepted. If one array in ``ts`` should not have markers, use
        `None`. E.g., if ``len(ts) == 3``, and we want only the first timecourse to have markers use:
        ``markersize=[5,None,None]``
    markerfc: str, list, optional
        Specify marker facecolor during plotting. A single option will be applied to all elements in ``ts``. A
        separate list for each element of ``ts`` is also accepted
    markerec: str, list, optional
        Specify marker edgecolor during plotting. A single option will be applied to all elements in ``ts``. A
        separate list for each element of ``ts`` is also accepted
    plot_alpha: float, list, optional
        Opacity of the lines. A single option will be applied to all elements in ``ts``. Default = 1
    plot_kw: dict, optional
        Further plotting options passed to ``matplotlib.pyplot.plot``, such as ``markerfacecolor`` (removes fill
        from markers, leaving only the edges), and ``markeredgewidth`` (sets the width of the marker outline).
        These options are passed to **every** elements in ``ts``.
    x_ticks: list, optional
        Locations where to put the ticks on the x-axis
    y_ticks: list, optional
        Locations where to put the ticks on the y-axis

    Example
    ----------
    >>> # create a bunch of timeseries
    >>> from linescanning import utils
    >>> ts = utils.random_timeseries(1.2, 0.0, 100)
    >>> ts1 = utils.random_timeseries(1.2, 0.3, 100)
    >>> ts2 = utils.random_timeseries(1.2, 0.5, 100)
    >>> ts3 = utils.random_timeseries(1.2, 0.8, 100)
    >>> ts4 = utils.random_timeseries(1.2, 1, 100)

    >>> # plot 1 timecourse
    >>> plotting.LazyPlot(ts2, figsize=(20, 5))
    <linescanning.plotting.LazyPlot at 0x7f839b0289d0>

    >>> # plot multiple timecourses, add labels, and save file
    >>> plotting.LazyPlot([ts, ts1, ts2, ts3, ts4], figsize=(20, 5), save_as="test_LazyPlot.pdf", labels=['vol=0', 'vol=0.3', 'vol=0.5', 'vol=0.8', 'vol=1.0'])
    <linescanning.plotting.LazyPlot at 0x7f839b2177c0>

    >>> # add horizontal line at y=0
    >>> hline = {'pos': 0, 'color': 'k', 'lw': 0.5, 'ls': '--'}
    >>> plotting.LazyPlot(ts2, figsize=(20, 5), add_hline=hline)
    <linescanning.plotting.LazyPlot at 0x7f839b053580>

    >>> # add shaded error bars
    >>> from scipy.stats import sem
    >>> # make some stack
    >>> stack = np.hstack((ts1[...,np.newaxis],ts2[...,np.newaxis],ts4[...,np.newaxis]))
    >>> avg = stack.mean(axis=-1) # calculate mean
    >>> err = sem(stack, axis=-1) # calculate error
    >>> plotting.LazyPlot(avg, figsize=(20, 5), error=err)
    <linescanning.plotting.LazyPlot at 0x7f839b0d5220>

    Notes
    ----------
    See https://linescanning.readthedocs.io/en/latest/examples/lazyplot.html for more examples
    """

    def __init__(
        self,
        ts,
        xx=None,
        error=None,
        error_alpha=0.3,
        figsize=(14,4),
        cmap='viridis',
        labels=None,
        markers=None,
        markersize=None,
        markerfc=None,
        markerec=None,
        plot_alpha=None,
        plot_kw=None,
        **kwargs):

        self.array = ts
        self.xx = xx
        self.error = error
        self.error_alpha = error_alpha
        self.plot_alpha = plot_alpha
        self.figsize = figsize
        self.cmap = cmap
        self.labels = labels
        self.markers = markers
        self.markersize = markersize
        self.markerfc = markerfc
        self.markerec = markerec

        # plot() writes marker colors into this dictionary, so every instance needs its own copy
        self.plot_kw = {} if plot_kw is None else dict(plot_kw)

        # let the parent process the kwargs so that font selection (e.g., for xkcd) sees the user's settings
        super().__init__(**kwargs)

        # plot
        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # save
        self._save_figure(self.save_as)

    @staticmethod
    def _broadcast(value, n, what):
        """repeat a single setting for all `n` timecourses; lists are checked for having `n` elements"""
        if not isinstance(value, list):
            value = [value for _ in range(n)]

        if len(value) != n:
            raise ValueError(f"{what} list ({len(value)}) does not match length of data list ({n})")

        return value

    def plot(self):
        """main plotting function"""

        # set figure axis
        self._set_figure_axs()

        # a bare array or a flat list of numbers represents a single timecourse
        flat_list = (
            isinstance(self.array, list)
            and len(self.array) > 0
            and all(np.isscalar(i) for i in self.array)
        )

        # sort out color
        if isinstance(self.array, np.ndarray) or flat_list:
            self.array = [np.asarray(self.array)]

            no_color = self.color is None or (
                isinstance(self.color, (str, list, tuple)) and len(self.color) == 0
            )
            if no_color:
                self.color = sns.color_palette(self.cmap, 1)[0]
            elif not (isinstance(self.color, list) and len(self.color) == 1):
                # single color specification (name, rgb-tuple, ..); 1-element lists already have the right format
                self.color = [self.color]

        if not isinstance(self.array, list):
            raise TypeError(f"Input must be a numpy.ndarray or list, not {type(self.array)}")

        n_series = len(self.array)
        if n_series == 0:
            raise ValueError("No data to plot: input list is empty")

        # check if alpha's match nr of elements in array
        self.plot_alpha = self._broadcast(
            1 if self.plot_alpha is None else self.plot_alpha,
            n_series,
            "Alpha"
        )

        if isinstance(self.color, str):
            self.color = [self.color for _ in range(n_series)]

        self.markers = self._broadcast(self.markers, n_series, "Marker")

        # errors: None -> no error; scalar -> same error for all; array -> error for a single timecourse
        self.fix_error = isinstance(self.error, (int, float, np.ndarray, list))
        if not isinstance(self.error, list):
            if not self.fix_error:
                self.error = [None for _ in range(n_series)]
            elif isinstance(self.error, (int, float)):
                self.error = [self.error for _ in range(n_series)]
            else:
                self.error = [self.error]

        if len(self.error) > n_series:
            raise ValueError(f"Error list ({len(self.error)}) does not match length of data list ({n_series})")

        # filter out all NaNs; timecourses without a (valid) error get None
        if self.fix_error:
            self.tmp_error = [None for _ in range(n_series)]
            for ix, err in enumerate(self.error):
                if err is None:
                    continue

                if np.isscalar(err):
                    self.tmp_error[ix] = err
                elif not np.all(np.isnan(np.asarray(err, dtype=float))):
                    self.tmp_error[ix] = err

            self.error = self.tmp_error.copy()

        self.markersize = self._broadcast(self.markersize, n_series, "Markersize")

        # decide on color scheme
        if not isinstance(self.color, list):
            self.color_list = sns.color_palette(self.cmap, n_series)
        else:
            self.color_list = self.color
            if len(self.color_list) != n_series:
                raise ValueError(f"Length color list ({len(self.color_list)}) does not match length of data list ({n_series})")

        for idx, el in enumerate(self.array):

            # squeeze dimensions
            el = np.asarray(el)
            if el.ndim > 1:
                el = el.squeeze()

            # decide on line-width
            if isinstance(self.line_width, list):
                if len(self.line_width) != n_series:
                    raise ValueError(f"Length of line widths ({len(self.line_width)}) does not match length of data list ({n_series})")

                use_width = self.line_width[idx]
            elif isinstance(self.line_width, (int, float)):
                use_width = self.line_width
            else:
                # let matplotlib decide
                use_width = None

            # decide on line-style
            if isinstance(self.line_style, list):
                if len(self.line_style) != n_series:
                    raise ValueError(f"Length of line styles ({len(self.line_style)}) does not match length of data list ({n_series})")

                use_style = self.line_style[idx]
            elif isinstance(self.line_style, str):
                use_style = self.line_style
            else:
                use_style = "solid"

            # decide on x-axis
            if isinstance(self.xx, range):
                # range has no copy attribute
                self.t_ = self.xx
            elif isinstance(self.xx, (pd.DataFrame, pd.Series)):
                self.t_ = self.xx.values
            elif isinstance(self.xx, list):
                # array-specific list: one array-like per timecourse. A list of plain numbers is a shared axis,
                # even if it happens to have as many elements as there are timecourses
                if len(self.xx) == n_series and not np.isscalar(self.xx[idx]):
                    self.t_ = self.xx[idx]
                else:
                    self.t_ = self.xx
            elif isinstance(self.xx, np.ndarray):
                if self.xx.ndim > 1:
                    self.xx = self.xx.squeeze()
                self.t_ = self.xx.copy()
            else:
                self.t_ = np.arange(0, len(el))

            if isinstance(self.labels, (list, np.ndarray)):
                lbl = self.labels[idx]
            else:
                lbl = None

            if isinstance(self.markerec, (list, str)):
                if isinstance(self.markerec, list):
                    use_ec = self.markerec[idx]
                else:
                    use_ec = self.markerec

                self.plot_kw["markeredgecolor"] = use_ec

            if isinstance(self.markerfc, (list, str)):
                if isinstance(self.markerfc, list):
                    use_fc = self.markerfc[idx]
                else:
                    use_fc = self.markerfc

                self.plot_kw["markerfacecolor"] = use_fc

            # plot
            self.axs.plot(
                self.t_,
                el,
                color=self.color_list[idx],
                label=lbl,
                lw=use_width,
                ls=use_style,
                marker=self.markers[idx],
                markersize=self.markersize[idx],
                alpha=self.plot_alpha[idx],
                **self.plot_kw)

            # plot shaded error bars
            self._set_shaded_error(
                x=self.t_,
                ax=self.axs,
                tc=el,
                yerr=self.error[idx],
                color=self.color_list[idx],
                alpha=self.error_alpha
            )

        # axis labels and titles
        self._set_legend_labels(self.axs, labels=self.labels)

        # set title
        self._set_title(self.axs, self.title)

        # set tick params
        self._set_tick_params(self.axs)

        # set spine widths
        self._set_spine_width(self.axs)

        # give priority to specify x-lims rather than seaborn's xlim; limits of 0 are valid limits too
        if not self.x_lim:
            if isinstance(self.xlim_left, (float, int)):
                self.axs.set_xlim(left=self.xlim_left)
            else:
                self.axs.set_xlim(left=self.t_[0])

            if isinstance(self.xlim_right, (float, int)):
                self.axs.set_xlim(right=self.xlim_right)
            else:
                self.axs.set_xlim(right=self.t_[-1])
        else:
            self.axs.set_xlim(self.x_lim)

        if not self.y_lim:
            if isinstance(self.ylim_bottom, (float, int)):
                self.axs.set_ylim(bottom=self.ylim_bottom)

            if isinstance(self.ylim_top, (float, int)):
                self.axs.set_ylim(top=self.ylim_top)
        else:
            self.axs.set_ylim(self.y_lim)

        # set ticks
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x", "y"]:
            for ff, el in zip(
                loop_funcs,
                ["ticks", "ticklabels", "dec", "label"]
                ):

                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(self.axs, add_to_ax, axis=x)

        # draw horizontal/vertical lines with ax?line
        self._add_line(ax=self.axs)

        # despine
        self._despine(self.axs)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
class LazyCorr(Defaults):
"""LazyCorr
Wrapper around seaborn's regplot. Plot data and a linear regression model fit. In addition to creating the plot, you can also run a regression or correlation using pingouin by setting the corresponding argument to ``True``.
Parameters
----------
data: pd.DataFrame, optional
Input DataFrame. In this case, use strings representing column names for ``x``, ``y``, and ``color_by``. Internally, the dataframe is parsed into arrays so that it's compatible with matplotlib's scatter-/ color-by functions
x: str, np.ndarray, list
First variable to include in plot/regression. Can be a list/array representing data, or a column name from ``data``
y: str, np.ndarray, list
Second variable to include in plot/regression. Can be a list/array representing data, or a column name from ``data``
color_by: str, np.ndarray, list
Color the points according to a separate array. Can be a list/array representing data, or a column name from ``data``. Default color map for this is 'viridis', and can be changed by passing arguments to `scatter_kwargs`
color: str, list, optional
String representing a color, by default "#cccccc" to color the regression fit
figsize: tuple, optional
Figure dimensions as per usual matplotlib conventions, by default (7,7)
axs: <AxesSubplot:>, optional
Matplotlib axis to store the figure on
correlation: bool, optional
Run a correlation between ``x`` and ``y``. The result is stored in ``self.correlation_result``
regression: bool, optional
Run a regression between ``x`` and ``y``. The result is stored in ``self.regression_result``
scatter_kwargs: dict, optional
Additional options passed on to the ``scatter`` function from matplotlib. Set colorbar to nothing by passing:
>>> scatter_kwargs={"cbar": False}
result_to_plot: bool, optional
Add the correlation/regression result to the plot
stat_kwargs: dict, optional
Options passed on to pingouin's stats functions
reg_kwargs: dict, optional
Options passed on to seaborn's regplot
Example
----------
>>> from linescanning import plotting
>>> import matplotlib.pyplot as plt
>>> # vanilla version; here, the regression fit has the same color as the dots.
>>> fig,axs = plt.subplots(figsize=(7,7))
>>> plotting.LazyCorr(
>>> x_data,
>>> y_data,
>>> axs=axs,
>>> x_label="add xlabel",
>>> y_label="add ylabel")
>>> # more exotic version: color each dot differently
>>> from linescanning import utils
>>> #
>>> fig,axs = plt.subplots(figsize=(7,7))
>>> #
>>> # create color map between red and blue; return as list
>>> colors = utils.make_between_cm(["r","b"], as_list=True, N=len(y_data))
>>> for ix,val in enumerate(y_data):
>>> axs.plot(x_data[ix], val, 'o', color=colors[ix], alpha=0.6)
>>> #
>>> #add the regression fit
>>> plotting.LazyCorr(
>>> x_data,
>>> y_data,
>>> axs=axs,
>>> add_points=False, # turn off points; we've already plotted them
>>> x_label="add xlabel",
>>> y_label="add ylabel")
Notes
----------
see documentation of :class:`linescanning.plotting.Defaults()` for formatting options
"""
def __init__(
    self,
    data: pd.DataFrame=None,
    x: Union[str,np.ndarray]=None,
    y: Union[str,np.ndarray]=None,
    color: str="#cccccc",
    figsize: tuple=(7,7),
    points: bool=True,
    label: str=None,
    scatter_kwargs: dict={},
    stat_kwargs: dict={},
    color_by: Union[list,np.ndarray]=None,
    regression: bool=False,
    correlation: bool=False,
    reg_kwargs: dict={},
    error_kwargs: dict={},
    result_to_plot: bool=False,
    result_loc: tuple=(0.55,0.1),
    result_ec: tuple=(1.,0.5,0.5),
    result_fc: tuple=(1.,0.8,0.8),
    result_dec: int=2,
    verbose=False,
    *args,
    **kwargs):
    """Store the inputs, draw the plot and optionally run/report statistics.

    Parameters that are not covered in the class docstring:

    points: bool, optional
        Passed to seaborn's regplot as ``scatter``, by default True
    label: str, optional
        Passed to seaborn's regplot as ``label``
    error_kwargs: dict, optional
        Options passed on to ``self._draw_errorbar``
    result_loc: tuple, optional
        Position of the result text box in axes coordinates, by default (0.55,0.1)
    result_ec: tuple, optional
        Edge color of the result text box
    result_fc: tuple, optional
        Face color of the result text box
    result_dec: int, optional
        Number of decimals the reported statistic and p-value are rounded to
    verbose: bool, optional
        Print the test result after running it
    *args:
        Accepted for backward compatibility; not used
    **kwargs:
        Passed on to the parent class
    """

    # init default plotter class
    super().__init__(**kwargs)

    self.data = data
    self.x = x
    self.y = y
    self.color = color
    self.figsize = figsize
    self.points = points
    self.label = label
    self.color_by = color_by
    self.regression = regression
    self.correlation = correlation
    self.result_to_plot = result_to_plot
    self.result_loc = result_loc
    self.result_fc = result_fc
    self.result_ec = result_ec
    self.result_dec = result_dec
    self.verbose = verbose

    # store private copies of the option dicts: later steps modify them, which
    # must not leak into the caller's dict or into the shared default object
    self.scatter_kwargs = dict(scatter_kwargs or {})
    self.stat_kwargs = dict(stat_kwargs or {})
    self.reg_kwargs = dict(reg_kwargs or {})
    self.error_kwargs = dict(error_kwargs or {})

    if self.xkcd:
        with plt.xkcd():
            self.plot()
    else:
        self.plot()

    # run quick regression with pingouin
    if self.regression:
        self._run_regression()

    # run quick correlation with pingouin
    if self.correlation:
        self._run_correlation()

    # print results?
    if self.verbose:
        self.print_results()

    # check if we should add text box; the method takes no positional
    # arguments, so forwarding *args could only ever raise a TypeError
    if self.result_to_plot:
        self.add_result_to_plot()

    # save
    self._save_figure(self.save_as)
def _get_result(self):
if hasattr(self, "correlation_result"):
met_name = "r"
p_name = "p-unc"
test = "correlation"
result_obj = self.correlation_result
meth = result_obj["method"].values[0]
elif hasattr(self, "regression_result"):
met_name = "r2"
p_name = "pval"
test = "regression"
result_obj = self.regression_result
meth = ""
p_val = result_obj[p_name][0]
met_val = result_obj[met_name][0]
return {
"p": p_val,
"test": test,
"metric": met_name,
"value": met_val,
"method": meth
}
def print_results(self, return_result=False):
    """Print a one-line summary of the stored test result.

    Parameters
    ----------
    return_result: bool, optional
        Also return the dictionary from ``self._get_result()``, by default False

    Returns
    ----------
    dict
        Only if ``return_result=True``
    """

    summary = self._get_result()

    # only mention the columns if both were given as names
    has_names = isinstance(self.x_col, str) and isinstance(self.y_col, str)
    col_txt = f"\tA={self.x_col}\tB={self.y_col}" if has_names else ""

    # only mention the method if there is one
    meth_txt = f" ({summary['method']})" if len(summary['method'])>0 else ""

    stat_txt = f"{summary['metric']}={round(summary['value'],self.result_dec)}"
    p_txt = f"p={round(summary['p'],self.result_dec)}"
    utils.verbose(f"Test={summary['test']}{meth_txt}{col_txt} | {stat_txt},\t{p_txt}", True)

    if return_result:
        return summary
def add_result_to_plot(self):
    """Write the statistic and p-value in a rounded text box on ``self.axs``.

    The result is fetched through ``self.print_results``, so it is printed as well. The box is positioned at ``self.result_loc`` in axes coordinates.
    """

    stats = self.print_results(return_result=True)

    # text to display
    ndec = self.result_dec
    box_txt = f"{stats['metric']}={round(stats['value'],ndec)}, p={round(stats['p'],ndec)}"

    # box around the text
    box_props = dict(
        boxstyle="round",
        ec=self.result_ec,
        fc=self.result_fc
    )

    self.axs.text(
        *self.result_loc,
        box_txt,
        size=self.font_size*0.8,
        bbox=box_props,
        transform=self.axs.transAxes
    )
def _run_regression(self):
try:
import pingouin as pg
except:
raise ImportError("Could not import pingouin, so this functionality is not available")
self.regression_result = pg.linear_regression(
self.x,
self.y,
remove_na=True,
**self.stat_kwargs
)
def _run_correlation(self):
try:
import pingouin as pg
except:
raise ImportError("Could not import pingouin, so this functionality is not available")
# convert to dataframe
self.data = pd.DataFrame({"x": self.x, "y": self.y})
self.x = "x"
self.y = "y"
self.correlation_result = pg.pairwise_corr(
self.data,
columns=["x","y"],
**self.stat_kwargs
)
def plot(self):
    """Draw the scatter points and regression fit on ``self.axs``.

    If ``self.data`` is a dataframe, ``self.x``/``self.y`` (and a string ``self.color_by``) are treated as column names and replaced by the column values; the original names are kept in ``self.x_col``/``self.y_col``. If ``self.color_by`` is array-like, the points are drawn with matplotlib's ``scatter`` (plus an optional colorbar) and seaborn's ``regplot`` only draws the fit. Afterwards the axis formatting of the parent class is applied.
    """

    # set figure axis
    self._set_figure_axs()

    # sort out stuff if input is data
    if isinstance(self.data, pd.DataFrame):

        # resetting the index is best-effort; keep the dataframe as is if it fails
        try:
            self.data = self.data.reset_index()
        except Exception:
            pass

        self.x_col = self.x
        self.y_col = self.y
        self.x = self.data[self.x].values
        self.y = self.data[self.y].values

        if isinstance(self.color_by, str):
            self.color_by = self.data[self.color_by].values.astype(float)

        self.data = None
    else:
        self.x_col = self.y_col = None

    # convert lists to arrays and enforce float; assign the result back so the
    # conversion actually reaches the attributes used below
    self.x = np.asarray(self.x, dtype=float)
    self.y = np.asarray(self.y, dtype=float)

    # c-arguments clashes with "color" argument if you pass it to sns.regplot in "scatter_kws"; hence this solution
    if isinstance(self.color_by, (list, np.ndarray, pd.DataFrame, pd.Series)):

        # get array
        if isinstance(self.color_by, (pd.DataFrame,pd.Series)):
            self.color_by = self.color_by.values

        # "cbar" is our own switch and not a matplotlib option; take it out of
        # a copy so the caller's dictionary stays untouched
        scatter_kws = dict(self.scatter_kwargs)
        add_cbar = scatter_kws.pop("cbar", True)

        points = self.axs.scatter(
            self.x,
            self.y,
            c=self.color_by,
            **scatter_kws)

        # set colorbar
        if add_cbar:
            self.cbar = plt.colorbar(points)
            if "label" in scatter_kws:
                self.cbar.set_label(
                    scatter_kws["label"],
                    fontsize=self.font_size,
                    fontname=self.fontname)

            # sort out ticks
            self._set_tick_params(self.cbar.ax)
            self._set_spine_width(self.cbar.ax)
            self._set_ticker(
                self.cbar.ax,
                self.y_dec,
                axis="y"
            )

            # remove outside edge from colorbar
            self.cbar.ax.set_frame_on(False)

        # set stuff to false/empty for sns.regplot
        self.points = False
        self.scatter_kwargs = {}

    # draw scatter error if desired
    self._draw_errorbar(
        x=self.x,
        y=self.y,
        ax=self.axs,
        **self.error_kwargs
    )

    self.kde_color = utils.make_between_cm(self.color,self.color,as_list=True)
    self.reg_ = sns.regplot(
        x=self.x,
        y=self.y,
        color=self.color,
        ax=self.axs,
        scatter=self.points,
        label=self.label,
        scatter_kws=self.scatter_kwargs,
        **self.reg_kwargs
    )

    # sort out ticks
    self._set_spine_width(self.axs)
    self._set_tick_params(self.axs)
    self._set_title(self.axs, self.title)

    # each formatter is called with the attribute "<axis>_<suffix>", e.g., self._set_ticks(axs, self.x_ticks, axis="x")
    loop_funcs = [
        "_set_ticks",
        "_set_ticklabels",
        "_set_ticker",
        "_set_axlabel"
    ]

    for x in ["x","y"]:
        for ff,el in zip(
            loop_funcs,
            ["ticks","ticklabels","dec","label"]):

            add_to_ax = getattr(self, f"{x}_{el}")
            getattr(self, ff)(self.axs, add_to_ax, axis=x)

    # draw horizontal/vertical lines with ax?line
    self._add_line(ax=self.axs)

    # set tickers & despine
    self._despine(self.axs)
[docs]
class LazyBar():
    """LazyBar

    Wrapper around :func:`seaborn.barplot` to follow the same aesthetics of the other Lazy* functions. It is strongly recommended to use a dataframe for this function to make the formatting somewhat easier, but you can input arrays for `x` and `y`. You can round the edges of the bar using `fancy=True`.

    Parameters
    ----------
    data: pd.DataFrame, optional
        Input dataframe, by default None
    x: str, list, np.ndarray, optional
        Variable for the x-axis, by default None. Can be a column name from `data`, or a list/np.ndarray with labels for input `y`.
    y: str, list, np.ndarray, optional
        Variable for the y-axis, by default None. Can be a column name from `data`, or a list/np.ndarray. If `x` is not specified, indices from 0 to `y.shape[0]` will be used to construct the input dataframe.
    labels: list, np.ndarray, optional
        custom labels that can be used when `x` denotes a column name in dataframe `data`. If there are fewer labels than rows, each unique value of `x` is replaced by the label with the same position; otherwise the column is overwritten with `labels`.
    palette: list, sns.palettes._ColorPalette, optional
        Color palette for the bars. A list is converted with :func:`seaborn.color_palette`
    cmap: str, optional
        Used as palette for the bars if `palette` is not a color palette, by default "inferno"
    hue: str, optional
        Column name passed on to :func:`seaborn.barplot` as `hue`
    figsize: tuple, optional
        Figure size passed on to the figure/axis setup, by default (4,7)
    add_labels: bool, list, optional
        If False (default), the ticks on the categorical axis are emptied
    lim: list, optional
        Limits of the value-axis, by default None. With `fancy=True`, the first element is used as the lower edge of the bars
    ticks: list, optional
        Ticks of the value-axis, by default None
    axs: <AxesSubplot:>, optional
        Subplot axis to put the plot on, by default None
    add_points: bool, optional
        Add the actual datapoints rather than just the bars, by default False
    points_color: str, tuple, optional
        Color of the points if you do not have nested categories, by default None. Takes priority over `points_palette` and `points_hue`
    points_palette: list, sns.palettes._ColorPalette, optional
        Color palette for the points if you have nested categories (e.g., multiple variables per subject so you can color the individual subjects' data points), by default None
    points_cmap: str, optional
        Color map for the points if you did not specify `points_palette`, by default "viridis"
    points_legend: bool, optional
        Add legend of the data points (if you have nested categories), by default False. The functionality of these interchangeable legends (`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised.
    points_alpha: float, optional
        Alpha of the points, by default 1. Sometimes useful to adjust if you have LOADS of data points
    error: str, optional
        Type of error bar to use for the bar, by default "se". Can be {'sem'|'se'} or {'std'|'sd'}; other values are passed on to seaborn unchanged
    fancy: bool, optional
        Flag to round the edges of the bars, by default False. The other `fancy`-arguments below are a bit vague, so leaving them default will ensure nice rounding of the bars
    fancy_rounding: float, optional
        Amount of rounding, by default 0.15
    fancy_pad: float, optional
        Vague variable, by default -0.004
    fancy_aspect: float, optional
        Vague variable, by default None. If None, it is derived from the upper axis limit (minus `lim[0]` if `lim` is a list) divided by `fancy_denom`
    fancy_denom: int, optional
        Scaling factor for `fancy_aspect`, by default 4 (which works well for data where the max value is ~50). Use higher values (e.g., 6) if your data range is large
    bar_legend: bool, optional
        Legend for the bars, rather than points, by default False. The functionality of these interchangeable legends (`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised.
    lbl_legend: list, optional
        Custom labels for the legend, by default None
    strip_kw: dict, optional
        Additional kwargs passed on to seaborn's stripplot. Several factors are being set via regular arguments in the function, such as `dodge`, `palette`, `color`, and `hue`.
    connect: bool, optional
        Draw lines between the points of neighbouring categories, by default False. Only used with `add_points=True`
    connect_kw: dict, optional
        Options for the connecting lines, passed on to matplotlib's `plot`. "color" ("black") and "alpha" (0.1) are handed to `utils.update_kwargs` as defaults

    Example
    ----------
    >>> # this figure size works well for plots with 2 bars
    >>> fig,axs = plt.subplots(figsize=(2,8))
    >>> plotting.LazyBar(
    >>>     data=df_wm,
    >>>     x="group",
    >>>     y="t1",
    >>>     sns_ori="v",
    >>>     axs=axs,
    >>>     add_labels=True,
    >>>     palette=[con_color,mdd_color],
    >>>     add_points=True,
    >>>     points_color="k",
    >>>     trim_bottom=True,
    >>>     sns_offset=4,
    >>>     y_label2="white matter T1 (ms)",
    >>>     lim=[800,1600],
    >>>     fancy=True,
    >>>     fancy_denom=6)

    Notes
    ----------
    see documentation of :class:`linescanning.plotting.Defaults()` for formatting options
    """

    def __init__(
        self,
        data: pd.DataFrame=None,
        x: Union[str,np.ndarray]=None,
        y: Union[str,np.ndarray]=None,
        labels: list=None,
        palette: Union[list,sns.palettes._ColorPalette]=None,
        cmap: str="inferno",
        hue: str=None,
        figsize=(4,7),
        add_labels: bool=False,
        lim: list=None,
        ticks: list=None,
        add_points: bool=False,
        points_color: Union[str,tuple]=None,
        points_palette: Union[list,sns.palettes._ColorPalette]=None,
        points_cmap: str="viridis",
        points_legend: bool=False,
        points_alpha: float=1,
        error: str="se",
        fancy: bool=False,
        fancy_rounding: float=0.15,
        fancy_pad: float=-0.004,
        fancy_aspect: float=None,
        fancy_denom: int=4,
        bar_legend: bool=False,
        lbl_legend: list=None,
        strip_kw: dict={},
        connect: bool=False,
        connect_kw: dict={},
        **kwargs):

        self.data = data
        self.x = x
        self.y = y
        self.hue = hue
        self.labels = labels
        self.palette = palette
        self.cmap = cmap
        self.add_labels = add_labels
        self.lim = lim
        self.ticks = ticks
        self.bar_legend = bar_legend
        self.lbl_legend = lbl_legend
        self.add_points = add_points
        self.points_color = points_color
        self.points_palette = points_palette
        self.points_cmap = points_cmap
        self.points_legend = points_legend
        self.points_alpha = points_alpha
        self.error = error
        self.fancy = fancy
        self.fancy_rounding = fancy_rounding
        self.fancy_pad = fancy_pad
        self.fancy_aspect = fancy_aspect
        self.fancy_denom = fancy_denom
        self.figsize = figsize
        self.connect = connect

        # keep private copies so neither the caller's dictionaries nor the shared defaults are touched
        self.strip_kw = dict(strip_kw or {})
        self.connect_kw = dict(connect_kw or {})

        self.kw_defaults = Defaults()

        # avoid that these kwargs are passed down to matplotlib.bar.. Throws errors
        ignore_kwargs = [
            "trim_left",
            "trim_bottom",
            "points_hue",
            "points_alpha",
            "bbox_to_anchor",
            "fancy",
            "fancy_rounding",
            "fancy_pad",
            "fancy_aspect",
            "fancy_denom",
            "font_name",
            "bar_legend",
            "labels",
            "strip_kw",
            "fontname",
            "add_legend",
            "connect",
            "connect_kw"
        ]

        # kwargs that are handled here rather than by seaborn
        ls_kwargs = self.kw_defaults.ls_kwargs
        handled_kwargs = ls_kwargs+ignore_kwargs

        kw_sns = {}
        for key,val in kwargs.items():
            if key not in handled_kwargs:
                # filter out non-ls kwargs
                kw_sns[key] = val
            elif key in ls_kwargs:
                # overwrite ls-kwargs
                setattr(self.kw_defaults, key, val)

        self.__dict__.update(**self.kw_defaults.__dict__)
        self.__dict__.update(**kwargs)
        self.kw_defaults.update_rc(self.fontname)

        if not hasattr(self, "bbox_to_anchor"):
            self.bbox_to_anchor = None

        if self.xkcd:
            with plt.xkcd():
                self.plot(**kw_sns)
        else:
            self.plot(**kw_sns)

        # save
        self.kw_defaults._save_figure(self.save_as)

    def connect_hue_pairs(self, **kwargs):
        """Connect the points of the hue-levels within each category on the categorical axis.

        For each unique value of column ``self.xx``, the matching rows are passed to :func:`connect_pairs` with ``self.hue`` as grouping column; `kwargs` are passed on.
        """

        n_xs = utils.get_unique_ids(self.data, id=self.xx, sort=False)
        for i,x in enumerate(n_xs):
            hue_data = self.data.loc[(self.data[self.xx] == x)]
            self.connect_pairs(
                hue_data,
                self.hue,
                hue_id=i,
                **kwargs
            )

    def connect_pairs(
        self,
        data,
        x,
        hue_id=None,
        **kwargs):
        """Draw lines between the points of subsequent groups.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing grouping column `x` and value column ``self.yy``
        x: str
            Column name defining the groups
        hue_id: int, optional
            If an integer, only the point collections ``hue_id*n_groups`` up to ``(hue_id+1)*n_groups`` are used
        **kwargs:
            Passed on to matplotlib's `plot`
        """

        # first assess which elements on the x-axis we have
        x_inputs = utils.get_unique_ids(
            data,
            id=x,
            sort=False
        )

        # find subsequent pairs
        pairs = utils.pairwise(x_inputs)

        connect_kws = {
            "color": "black",
            "alpha": 0.1
        }

        for key,val in connect_kws.items():
            kwargs = utils.update_kwargs(
                kwargs,
                key,
                val
            )

        # get children of the axis
        children = self.ff.get_children()

        # make all strings so we can regex the "collections"
        child_str_list = [str(i) for i in children]
        str_collection = utils.get_file_from_substring("collections", child_str_list)

        # then index in actual children list
        real_collection = [children[child_str_list.index(i)] for i in str_collection]

        # get correct collections based on hue_id
        if isinstance(hue_id, int):
            start_idx = hue_id*len(x_inputs)
            real_collection = real_collection[start_idx:start_idx+len(x_inputs)]

        # find collections
        collections = utils.pairwise(real_collection)

        # offsets are stored as (x,y); swap the columns for horizontal bars
        if self.sns_ori == "v":
            x_idx,y_idx = 0,1
        else:
            x_idx,y_idx = 1,0

        for pair,coll in zip(pairs,collections):

            # find starting index of pair list
            locs1 = coll[0].get_offsets()
            locs2 = coll[1].get_offsets()

            # find subdatasets
            sets = [data.loc[data[x] == i][self.yy].values for i in pair]

            # before plotting, we need to sort so that the data points correspond to each other
            sort_idxs1 = np.argsort(sets[0])
            sort_idxs2 = np.argsort(sets[1])

            # revert "ascending sort" through sort_idxs2.argsort(),
            # and then sort into order corresponding with set1
            locs2_sorted = locs2[sort_idxs2.argsort()][sort_idxs1]

            for i in range(locs1.shape[0]):
                x1 = [locs1[i, x_idx], locs2_sorted[i, x_idx]]
                y1 = [locs1[i, y_idx], locs2_sorted[i, y_idx]]
                self.ff.plot(x1, y1, **kwargs)

    def plot(self, **kw_sns):
        """Draw the bar plot (plus optional points, connecting lines, and legend) and format the axis.

        Parameters
        ----------
        **kw_sns:
            Additional keyword arguments passed on to :func:`seaborn.barplot`

        Raises
        ----------
        ValueError
            If ``self.sns_ori`` is not "v" or "h"
        """

        # set figure axis
        self.kw_defaults._set_figure_axs(figsize=self.figsize)

        # construct dataframe from loose inputs
        if isinstance(self.y, (np.ndarray,list)):
            if isinstance(self.y, list):
                self.y = np.array(self.y)

            if not isinstance(self.x, (np.ndarray, list)):
                self.x = np.arange(0,self.y.shape[0])

            self.data = pd.DataFrame({"x": self.x, "y": self.y})
            self.x = "x"
            self.y = "y"

        # resetting the index is best-effort; keep the data as is if it fails
        try:
            self.data = self.data.reset_index()
        except Exception:
            pass

        # check if we got custom labels
        if isinstance(self.labels, (np.ndarray,list)):

            # check if x is multi-level
            if len(self.data[self.x])>len(self.labels):

                # get unique values
                unique_x = utils.get_unique_ids(
                    self.data,
                    id=self.x,
                    sort=False
                )

                # substitute each old x for the label at the same position in unique x
                new_x = []
                for x in self.data[self.x].values:
                    curr_x = unique_x.index(x)
                    new_x.append(self.labels[curr_x])

                self.data[self.x] = new_x
            else:
                self.data[self.x] = self.labels

        if self.sns_ori == "h":
            self.xx = self.y
            self.yy = self.x
            self.trim_bottom = False
            self.trim_left = True
        elif self.sns_ori == "v":
            self.xx = self.x
            self.yy = self.y
            self.trim_bottom = True
            self.trim_left = False
        else:
            raise ValueError(f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

        # a single color disables the palette; a list of colors becomes the palette
        if isinstance(self.color, (str,tuple,list)):
            if isinstance(self.color, (str,tuple)):
                self.palette = None
                self.cmap = None
            elif isinstance(self.color, list):
                self.palette = sns.color_palette(palette=self.color)
                self.color = None
        else:
            self.color = None

        if isinstance(self.palette, list):
            self.palette = sns.color_palette(palette=self.palette)

        if not isinstance(self.palette, sns.palettes._ColorPalette):
            self.palette = self.cmap

        # allow more input types; compare against the full names instead of a
        # substring test, which would also accept "", "s", "e", "m" or "em"
        if isinstance(self.error, str):
            error_type = self.error.lower()
            if error_type in ("sem","se"):
                self.error = "se"
            elif error_type in ("std","sd"):
                self.error = "sd"

        self.ff = sns.barplot(
            data=self.data,
            x=self.xx,
            y=self.yy,
            ax=self.axs,
            orient=self.sns_ori,
            errorbar=self.error,
            hue=self.hue,
            **dict(
                kw_sns,
                color=self.color,
                palette=self.palette
            ))

        multi_strip = False
        if self.add_points:

            if not hasattr(self, "points_hue"):
                self.points_hue = None

            if not self.points_palette:
                self.points_palette = self.points_cmap

            # give priority to given points_color
            if isinstance(self.points_color, (str,tuple)):
                self.points_palette = None
                self.points_hue = None

            if isinstance(self.hue, str):

                if isinstance(self.points_hue, str):
                    if self.points_hue != self.hue:

                        # one stripplot per level of points_hue, each in its own color
                        multi_strip = True
                        self.hue_items = list(np.unique(self.data[self.points_hue].values))

                        if isinstance(self.points_color, (str,tuple)):
                            self.hue_colors = [self.points_color for ii in range(len(self.hue_items))]
                        else:
                            self.hue_colors = sns.color_palette(self.points_palette, len(self.hue_items))

                        for it, color in zip(self.hue_items, self.hue_colors):
                            df_per_it = self.data[self.data[self.points_hue] == it]

                            sns.stripplot(
                                data=df_per_it,
                                x=self.xx,
                                y=self.yy,
                                hue=self.hue,
                                dodge=False,
                                palette=[color] * 2,
                                ax=self.ff,
                                **self.strip_kw
                            )
                else:
                    multi_strip = True
                    sns.stripplot(
                        data=self.data,
                        x=self.xx,
                        y=self.yy,
                        hue=self.hue,
                        dodge=True,
                        ax=self.ff,
                        color=self.points_color,
                        palette=self.points_palette,
                        alpha=self.points_alpha,
                        **self.strip_kw
                    )
            else:
                sns.stripplot(
                    data=self.data,
                    x=self.xx,
                    y=self.yy,
                    hue=self.points_hue,
                    dodge=False,
                    ax=self.ff,
                    color=self.points_color,
                    palette=self.points_palette,
                    alpha=self.points_alpha,
                    **self.strip_kw
                )

            # connect dots
            if self.connect:
                if isinstance(self.hue, str):
                    self.connect_hue_pairs(**self.connect_kw)
                else:
                    self.connect_pairs(
                        self.data,
                        self.xx,
                        **self.connect_kw
                    )

        # sort out legend
        if self.bar_legend or self.points_legend:

            self.add_legend = True

            self.legend_kw = {
                "fontsize": self.label_size,
                "handletextpad": self.legend_handletext,
                "frameon": False
            }

            if isinstance(self.bbox_to_anchor, tuple):
                self.legend_kw["bbox_to_anchor"] = self.bbox_to_anchor

            # get handles
            handles,labels = self.ff.get_legend_handles_labels()

            # bar legend
            if self.bar_legend:

                # do some more exotic stuff to disentangle coloring from bars and hue
                if isinstance(self.hue, str):

                    # find categorical handles
                    handles,labels = self.ff.get_legend_handles_labels()

                    # unique hue-levels in order of appearance
                    cc = self.data[self.hue].values
                    indexes = np.unique(cc, return_index=True)[1]
                    cond = [cc[index] for index in sorted(indexes)]

                    if multi_strip:
                        # keep the last handles, one per hue-level
                        handles = handles[-len(cond):]
                        labels = labels[-len(cond):]
                    else:
                        if not self.add_points:
                            self.add_legend = False
                else:
                    self.add_legend = False

            # fill in legend
            if self.add_legend:
                if isinstance(self.lbl_legend, list):
                    labels = self.lbl_legend

                self.ff.legend(
                    handles,
                    labels,
                    **self.legend_kw
                )
        else:
            self.ff.legend([],[], frameon=False)

        # set tick params
        self.kw_defaults._set_tick_params(self.ff)

        # set spine widths
        self.kw_defaults._set_spine_width(self.ff)

        if not self.add_labels:
            if self.sns_ori == 'h':
                self.kw_defaults._set_ticks(self.ff, [], axis="y")
            elif self.sns_ori == "v":
                self.kw_defaults._set_ticks(self.ff, [], axis="x")
            else:
                raise ValueError(f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")
        elif isinstance(self.add_labels,list):
            self.kw_defaults._set_xlabel(self.ff, self.add_labels)

        if isinstance(self.sns_rot, (int,float)):
            if self.sns_ori == 'h':
                self.kw_defaults._set_ticklabels(
                    self.ff,
                    self.ff.get_yticklabels(),
                    rotation=self.sns_rot,
                    fontname=self.fontname,
                    axis="y"
                )
            elif self.sns_ori == "v":
                self.kw_defaults._set_ticklabels(
                    self.ff,
                    self.ff.get_xticklabels(),
                    axis="x",
                    rotation=self.sns_rot,
                    fontname=self.fontname
                )
            else:
                raise ValueError(f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

        # set limits depending on orientation
        self.kw_defaults._set_bar_lim(self.ff, self.lim)

        # set ticks depending on orientation
        self.kw_defaults._set_bar_ticks(self.ff, self.ticks)

        # from: https://stackoverflow.com/a/61569240
        if self.fancy:
            new_patches = []
            for patch in reversed(self.ff.patches):

                bb = patch.get_bbox()
                color = patch.get_facecolor()

                # derive the aspect from the axis limit once; it is reused for all following bars
                if not isinstance(self.fancy_aspect, (int,float)):
                    if self.sns_ori == "v":
                        y_limiter = patch._axes.get_ylim()[-1]
                        if isinstance(self.lim, list):
                            y_limiter-=self.lim[0]

                        self.fancy_aspect = y_limiter/self.fancy_denom
                    else:
                        x_limiter = patch._axes.get_xlim()[-1]
                        if isinstance(self.lim, list):
                            x_limiter-=self.lim[0]

                        self.fancy_aspect = x_limiter/self.fancy_denom

                # make rounding at limit
                if isinstance(self.lim, list):
                    if self.sns_ori == "v":
                        ymin = self.lim[0]
                        xmin = bb.xmin
                        height = bb.height - ymin
                        width = bb.width
                    else:
                        xmin = self.lim[0]
                        ymin = bb.ymin
                        width = bb.width - xmin
                        height = bb.height
                else:
                    xmin = bb.xmin
                    ymin = bb.ymin
                    height = bb.height
                    width = bb.width

                p_bbox = patches.FancyBboxPatch(
                    (xmin, ymin),
                    abs(width), abs(height),
                    boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}",
                    ec="none",
                    fc=color,
                    mutation_aspect=self.fancy_aspect
                )

                patch.remove()
                new_patches.append(p_bbox)

            for patch in new_patches:
                self.ff.add_patch(patch)

        # set xlabel to none if nothing is specified
        if isinstance(self.x, str) and not isinstance(self.x_label, str):
            self.ff.set(xlabel=None)

        if isinstance(self.y, str) and not isinstance(self.y_label, str):
            self.ff.set(ylabel=None)

        # set these explicitly; remove left axis if orientation = horizontal | remove bottom axis if orientation is vertical
        trim_left = getattr(self, "trim_left", False)
        trim_bottom = getattr(self, "trim_bottom", False)

        # draw horizontal/vertical lines with ax?line
        self.kw_defaults._add_line(ax=self.ff)

        # each formatter is called with the attribute "<axis>_<suffix>", e.g., _set_ticker(ff, self.x_dec, axis="x")
        loop_funcs = [
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x","y"]:
            for ff,el in zip(
                loop_funcs,
                ["dec","label"]):

                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self.kw_defaults, ff)(self.ff, add_to_ax, axis=x)

        self.kw_defaults._despine(
            self.ff,
            left=trim_left,
            bottom=trim_bottom
        )

        # set title
        self.kw_defaults._set_title(self.ff, self.title)
[docs]
class LazyHist(Defaults):
    """LazyHist

    Wrapper around seaborn's histogram plotter

    Parameters
    ----------
    data: numpy.ndarray
        Input data for histogram
    x: str, optional
        Passed on to seaborn's `kdeplot` as `x`, by default None
    y: str, optional
        Passed on to seaborn's `kdeplot` as `y`, by default None
    figsize: tuple, optional
        Figure dimensions, by default (5,5)
    kde: bool, optional
        Add kernel density plot to histogram with seaborn (https://seaborn.pydata.org/generated/seaborn.kdeplot.html). Default is False
    hist: bool, optional
        Add histogram to plot. Default is True
    bins: str, optional
        Set bins for histogram; default = "auto"
    fill: bool, optional
        Fill the area below the kde plot. Default is False
    kde_kwargs: dict, optional
        Additional arguments passed on to seaborn's `kdeplot`
    hist_kwargs: dict, optional
        Additional arguments passed on to matplotlib's `hist` function
    color: str, optional
        Color of the histogram and kde line, by default "#cccccc"
    cmap: str, optional
        If a string, `color` is replaced by the first color of this seaborn palette
    fancy: bool, optional
        Flag to round the edges of the histogram bars, by default False
    fancy_rounding: float, optional
        Amount of rounding, by default 0.15
    fancy_pad: float, optional
        Padding of the rounded boxes, by default -0.004
    fancy_aspect: float, optional
        Aspect of the rounding, by default None. If None, the upper y-limit divided by 4 is used

    Returns
    ----------
    matplotlib.pyplot plot

    Example
    ----------
    >>> from linescanning import plotting
    >>> import matplotlib.pyplot as plt
    >>> fig,axs = plt.subplots(figsize=(7,7))
    >>> plotting.LazyHist(
    >>>     y_data,
    >>>     axs=axs,
    >>>     kde=True,
    >>>     hist=True,
    >>>     fill=False,
    >>>     y_label2="add y_label",
    >>>     x_label2="add x_label",
    >>>     hist_kwargs={"alpha": 0.4},
    >>>     kde_kwargs={"linewidth": 4}
    >>> )

    Notes
    ----------
    see documentation of :class:`linescanning.plotting.Defaults()` for formatting options
    """

    def __init__(
        self,
        data,
        x=None,
        y=None,
        figsize=(5,5),
        kde=False,
        hist=True,
        bins="auto",
        fill=False,
        kde_kwargs={},
        hist_kwargs={},
        color="#cccccc",
        cmap=None,
        fancy: bool=False,
        fancy_rounding: float=0.15,
        fancy_pad: float=-0.004,
        fancy_aspect: float=None,
        **kwargs):

        super().__init__()
        self.__dict__.update(kwargs)
        self.update_rc(self.fontname)

        # read regular arguments
        self.data = data
        self.x = x
        self.y = y
        self.figsize = figsize
        self.kde = kde
        self.hist = hist
        self.bins = bins
        self.fill = fill
        self.color = color
        self.kwargs = kwargs
        self.fancy = fancy
        self.fancy_rounding = fancy_rounding
        self.fancy_pad = fancy_pad
        self.fancy_aspect = fancy_aspect
        self.cmap = cmap

        # plot() adds entries to kde_kwargs; work on copies so these entries do not
        # end up in the caller's dict or in the default shared between instances
        self.kde_kwargs = dict(kde_kwargs or {})
        self.hist_kwargs = dict(hist_kwargs or {})

        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # storing the kde-curve is best-effort; there is no line to read when the kde is drawn as filled area
        if self.kde:
            try:
                self.kde_ = self.return_kde()
            except (IndexError, AttributeError):
                pass

        # save
        self._save_figure(self.save_as)

    def return_kde(self):
        """Return the (x,y)-data of the first line on the kde-axis."""
        return self.ff.get_lines()[0].get_data()

    def force_kde_color(self):
        """Set the color of the last line on the kde-axis to ``self.color``; does nothing if the axis has no lines."""
        line_cols = self.ff.get_lines()
        if len(line_cols)>0:
            line_cols[-1].set_color(self.color)

    def _set_color(self):
        """Overwrite ``self.color`` with the first color of palette ``self.cmap`` if it is a string."""
        if isinstance(self.cmap, str):
            self.color = sns.color_palette(self.cmap, 1)[0]

    def plot(self):
        """Draw the histogram and/or kde on ``self.axs`` and apply the axis formatting."""

        # overwrite color argument if cmap is specified
        self._set_color()

        # set figure axis
        self._set_figure_axs()

        if self.hist:
            self.vals, self.bins, self.patches = self.axs.hist(
                self.data,
                density=True,
                bins=self.bins,
                color=self.color,
                **self.hist_kwargs
            )

            # from: https://stackoverflow.com/a/61569240
            if self.fancy:
                new_patches = []
                for patch in reversed(self.patches):

                    # max of axis divided by 4 gives nice rounding
                    if not isinstance(self.fancy_aspect, (int,float)):
                        self.fancy_aspect = patch._axes.get_ylim()[-1]/4

                    bb = patch.get_bbox()
                    color = patch.get_facecolor()
                    p_bbox = patches.FancyBboxPatch(
                        (bb.xmin, bb.ymin),
                        abs(bb.width), abs(bb.height),
                        boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}",
                        ec="none",
                        fc=color,
                        mutation_aspect=self.fancy_aspect
                    )

                    patch.remove()
                    new_patches.append(p_bbox)

                for patch in new_patches:
                    self.axs.add_patch(patch)

        if self.kde:

            # turn off legend by default
            self.kde_kwargs.setdefault("legend", False)

            # without histogram, the kde gets the main color unless specified otherwise
            if not self.hist:
                self.kde_kwargs.setdefault("color", self.color)

            self.ff = sns.kdeplot(
                data=self.data,
                x=self.x,
                y=self.y,
                ax=self.axs,
                fill=self.fill,
                **self.kde_kwargs
            )

            # the color argument is very unstable for some reason..
            if self.hist:
                self.force_kde_color()

        # there's no self.ff if kde=False
        if hasattr(self, "ff"):
            self.active_axs = self.ff
        else:
            self.active_axs = self.axs

        # set tick params/axis width
        self._set_tick_params(self.active_axs)
        self._set_spine_width(self.active_axs)

        # set limits
        self._set_xlim(self.active_axs, self.x_lim)
        self._set_ylim(self.active_axs, self.y_lim)

        # each formatter is called with the attribute "<axis>_<suffix>", e.g., self._set_ticks(axs, self.x_ticks, axis="x")
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x","y"]:
            for ff,el in zip(
                loop_funcs,
                ["ticks","ticklabels","dec","label"]):

                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(self.active_axs, add_to_ax, axis=x)

        # remove axis labels if none were specified
        if not isinstance(self.x_label, str):
            self.active_axs.set(xlabel=None)

        if not isinstance(self.y_label, str):
            self.active_axs.set(ylabel=None)

        # trim_left is read from the instance, trim_bottom only if passed as kwarg
        trim_left = getattr(self, "trim_left", False)
        trim_bottom = self.kwargs.get("trim_bottom", False)

        self._despine(
            self.active_axs,
            left=trim_left,
            bottom=trim_bottom
        )

        # draw horizontal/vertical lines with ax?line
        self._add_line(ax=self.active_axs)

        # set title
        self._set_title(self.active_axs, self.title)
[docs]
class LazyColorbar(Defaults):
    """LazyColorbar

    Draw a standalone colorbar on its own axis.

    Parameters
    ----------
    cmap: str, matplotlib colormap, optional
        Color map; a string is converted to a colormap with 256 entries. Default = "magma_r"
    txt: str, optional
        Label of the colorbar. If it contains "polar", the automatically generated ticks are [-pi,0,pi]
    vmin: float, optional
        Lower bound of the color range, by default 0
    vmax: float, optional
        Upper bound of the color range, by default 10
    ori: str, optional
        Orientation of the colorbar, by default "vertical". For vertical colorbars, the elements of `figsize` are swapped
    ticks: list, np.ndarray, optional
        Ticks of the colorbar. If not specified, they are generated with :func:`colormap_ticks`
    labels: list, np.ndarray, optional
        Custom tick labels
    flip_ticks: bool, optional
        Put the ticks on the left (vertical) or top (horizontal), by default False
    flip_label: bool, optional
        Put the label on the left (vertical) or top (horizontal), by default False
    figsize: tuple, optional
        Figure dimensions, by default (6,0.5)
    save_as: str, optional
        Passed on to ``self._save_figure``
    cm_nr: int, optional
        Number of automatically generated ticks, by default 5
    cm_decimal: int, optional
        Number of decimals of automatically generated ticks, by default 3
    cb_kws: dict, optional
        Additional options passed on to :class:`matplotlib.colorbar.Colorbar`
    font_kws: dict, optional
        Options for :class:`matplotlib.font_manager.FontProperties` of the label. "color" is taken out and used as color for label and tick labels (default = "black")

    Notes
    ----------
    see documentation of :class:`linescanning.plotting.Defaults()` for formatting options
    """

    def __init__(
        self,
        cmap="magma_r",
        txt=None,
        vmin=0,
        vmax=10,
        ori="vertical",
        ticks=None,
        labels=None,
        flip_ticks=False,
        flip_label=False,
        figsize=(6,0.5),
        save_as=None,
        cm_nr=5,
        cm_decimal=3,
        cb_kws={},
        font_kws={},
        **kwargs):

        self.cmap = cmap
        self.txt = txt
        self.vmin = vmin
        self.vmax = vmax
        self.ori = ori
        self.ticks = ticks
        self.flip_ticks = flip_ticks
        self.flip_label = flip_label
        self.figsize = figsize
        self.save_as = save_as
        self.cm_nr = cm_nr
        self.cm_decimal = cm_decimal
        self.labels = labels

        # "color" is removed from font_kws below; use a copy so the caller's dict stays intact
        self.font_kws = dict(font_kws or {})

        if self.ori == "vertical":
            self.figsize = (self.figsize[1],self.figsize[0])

        super().__init__(**kwargs)

        # set figure axis
        self._set_figure_axs()

        self.__dict__.update(kwargs)
        self.update_rc(self.fontname)

        if self.axs is None:
            self.fig, self.axs = plt.subplots(figsize=self.figsize)

        # make colorbase instance; pyplot's get_cmap is used because
        # matplotlib.cm.get_cmap no longer exists in recent matplotlib versions
        if isinstance(self.cmap, str):
            self.cmap = plt.get_cmap(self.cmap, 256)

        # decide ticks
        if not isinstance(self.ticks, (np.ndarray,list)):
            self.ticks = self.colormap_ticks(
                vmin=self.vmin,
                vmax=self.vmax,
                key=self.txt,
                dec=self.cm_decimal,
                nr=self.cm_nr
            )

        # plop everything in class
        mpl.colorbar.Colorbar(
            self.axs,
            orientation=self.ori,
            cmap=self.cmap,
            norm=mpl.colors.Normalize(vmin,vmax),
            label=self.txt,
            ticks=self.ticks,
            **(cb_kws or {})
        )

        # the axis carrying ticks and label depends on the orientation
        if self.ori == "vertical":
            tick_ax = "y"
            if self.flip_ticks:
                self.axs.yaxis.set_ticks_position("left")

            if self.flip_label:
                self.axs.yaxis.set_label_position("left")

            text = self.axs.yaxis.label
        else:
            tick_ax = "x"
            if self.flip_ticks:
                self.axs.xaxis.set_ticks_position("top")

            if self.flip_label:
                self.axs.xaxis.set_label_position("top")

            text = self.axs.xaxis.label

        # color is not a font property, so take it out before creating the font
        ft_color = self.font_kws.pop("color", "black")

        font = mpl.font_manager.FontProperties(
            size=self.font_size,
            **self.font_kws
        )

        text.set_font_properties(font)
        text.set_color(ft_color)

        # fix ticks
        self._set_tick_params(
            self.axs,
            labelcolor=ft_color,
        )

        # set tick labels?
        if isinstance(self.labels, (np.ndarray,list)):
            self._set_ticklabels(
                self.axs,
                self.labels,
                axis=tick_ax
            )

        # turn off frame
        self.axs.set_frame_on(False)

        # save
        self._save_figure(self.save_as)

    @staticmethod
    def colormap_ticks(
        vmin=None,
        vmax=None,
        key=None,
        dec=3,
        nr=5):
        """Generate ticks for a colorbar.

        Parameters
        ----------
        vmin: float, optional
            Lower bound of the color range
        vmax: float, optional
            Upper bound of the color range
        key: str, optional
            If this string contains "polar", the ticks are [-pi,0,pi]. Otherwise, `nr` equally spaced ticks between `vmin` and `vmax` (inclusive) are used
        dec: int, optional
            Number of decimals to round the ticks to, by default 3
        nr: int, optional
            Number of ticks, by default 5

        Returns
        ----------
        list
            Rounded ticks; the first/last tick is replaced by `vmin` rounded up/`vmax` rounded down if it falls outside [vmin,vmax]
        """

        if isinstance(key, str) and "polar" in key:
            ticks = [-np.pi,0,np.pi]
        else:
            ticks = list(np.linspace(vmin,vmax, endpoint=True, num=nr))

        # round ticks
        ticks = [round(ii,dec) for ii in ticks]

        # make sure the first tick does not fall below the minimum
        if ticks[0] < vmin:
            ticks[0] = utils.round_decimals_up(vmin, dec)

        # make sure the last tick does not exceed the maximum
        if ticks[-1] > vmax:
            ticks[-1] = utils.round_decimals_down(vmax, dec)

        return ticks

    def show(self):
        """Attach ``self.fig`` to the canvas manager of a newly created figure."""
        fig = plt.figure()
        new_manager = fig.canvas.manager
        new_manager.canvas.figure = self.fig
        self.fig.set_canvas(new_manager.canvas)
def fig_annot(
    fig,
    axs=None,
    y=1.01,
    x0_corr=0,
    x_corr=-0.09,
    fontsize=28,
    lower=False,
    brackets=False,
    square=False,
    **kwargs):
    """fig_annot

    Annotate the axes of a figure with consecutive letters (A, B, C, ...), e.g., to label panels.

    Parameters
    ----------
    fig: figure object
        Figure holding the axes; its `axes` attribute is used when `axs` is not a list, and its `dpi_scale_trans` is used to scale each axis' window extent
    axs: list, optional
        Axes to annotate, in order. If not a list, all axes in `fig.axes` are annotated
    y: float, list, optional
        Vertical position of the letter in axes-fraction coordinates. Either one value for all axes or a list with one value per axis. Default = 1.01
    x0_corr: float, int, optional
        Horizontal offset of the letter of the first axis; divided by the width of the axis to obtain an axes fraction. Default = 0
    x_corr: float, int, list, optional
        Horizontal offset for all axes except the first; either a single value or a sequence with one value per remaining axis (number of axes minus 1). Divided by the width of the axis. Default = -0.09
    fontsize: int, optional
        Font size of the letters. Default = 28
    lower: bool, optional
        Use lowercase instead of uppercase letters. Default = False
    brackets: bool, optional
        Wrap the letter in brackets. Default = False
    square: bool, optional
        With `brackets=True`, use square brackets "[A]" rather than round ones "(A)". Default = False
    **kwargs: dict
        Passed on to `ax.annotate`

    Raises
    ----------
    ValueError
        If `x_corr` is a sequence whose length is not the number of axes minus 1, or `y` is a list whose length is not the number of axes
    IndexError
        If more axes are given than there are letters in the alphabet (26)
    """

    # get figure letters
    alphabet = string.ascii_lowercase if lower else string.ascii_uppercase

    ax_list = axs if isinstance(axs, list) else fig.axes

    x_is_scalar = isinstance(x_corr, (float, int))
    y_is_list = isinstance(y, list)

    # the length checks do not depend on the axis, so run them once up front;
    # as before, nothing is validated when there are no axes to annotate
    if ax_list:
        if not x_is_scalar and len(x_corr) != len(ax_list)-1:
            raise ValueError(f"List with 'x_corr'-values must match list with axes-1. x_corr contains {len(x_corr)} elements, while {len(ax_list)-1} axes are specified")

        if y_is_list and len(y) != len(ax_list):
            raise ValueError(f"List with y-values must match list with axes. y contains {len(y)} elements, while {len(ax_list)} axes are specified")

    # make annotations
    for ix, ax in enumerate(ax_list):

        # window extent of the axis scaled by the inverse dpi transform
        # (presumably inches -- confirm), so offsets can be turned into axes fractions
        bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())

        # The first axis has its own offset. Only index x_corr for later axes:
        # indexing at ix == 0 would read x_corr[-1] and fail on an empty sequence.
        if ix == 0:
            offset = x0_corr
        elif x_is_scalar:
            offset = x_corr
        else:
            offset = x_corr[ix-1]

        pos = offset/bbox.width
        y_pos = y[ix] if y_is_list else y

        txt = alphabet[ix]
        if brackets:
            txt = f"[{txt}]" if square else f"({txt})"

        ax.annotate(
            txt,
            (pos, y_pos),
            fontsize=fontsize,
            xycoords="axes fraction",
            **kwargs)
# Look-up table with 256 RGB triplets (one row per color step, columns are
# red/green/blue, all values between 0 and 1). It is converted with
# RGBToPyCmap() and registered with matplotlib further down in this file under
# the names 'turbo' and 'turbo_r' (rows reversed).
# NOTE(review): the values appear to be the "Turbo" rainbow colormap table --
# confirm the upstream source/licence.
turbo_colormap_data = np.array(
[[0.18995,0.07176,0.23217],
[0.19483,0.08339,0.26149],
[0.19956,0.09498,0.29024],
[0.20415,0.10652,0.31844],
[0.20860,0.11802,0.34607],
[0.21291,0.12947,0.37314],
[0.21708,0.14087,0.39964],
[0.22111,0.15223,0.42558],
[0.22500,0.16354,0.45096],
[0.22875,0.17481,0.47578],
[0.23236,0.18603,0.50004],
[0.23582,0.19720,0.52373],
[0.23915,0.20833,0.54686],
[0.24234,0.21941,0.56942],
[0.24539,0.23044,0.59142],
[0.24830,0.24143,0.61286],
[0.25107,0.25237,0.63374],
[0.25369,0.26327,0.65406],
[0.25618,0.27412,0.67381],
[0.25853,0.28492,0.69300],
[0.26074,0.29568,0.71162],
[0.26280,0.30639,0.72968],
[0.26473,0.31706,0.74718],
[0.26652,0.32768,0.76412],
[0.26816,0.33825,0.78050],
[0.26967,0.34878,0.79631],
[0.27103,0.35926,0.81156],
[0.27226,0.36970,0.82624],
[0.27334,0.38008,0.84037],
[0.27429,0.39043,0.85393],
[0.27509,0.40072,0.86692],
[0.27576,0.41097,0.87936],
[0.27628,0.42118,0.89123],
[0.27667,0.43134,0.90254],
[0.27691,0.44145,0.91328],
[0.27701,0.45152,0.92347],
[0.27698,0.46153,0.93309],
[0.27680,0.47151,0.94214],
[0.27648,0.48144,0.95064],
[0.27603,0.49132,0.95857],
[0.27543,0.50115,0.96594],
[0.27469,0.51094,0.97275],
[0.27381,0.52069,0.97899],
[0.27273,0.53040,0.98461],
[0.27106,0.54015,0.98930],
[0.26878,0.54995,0.99303],
[0.26592,0.55979,0.99583],
[0.26252,0.56967,0.99773],
[0.25862,0.57958,0.99876],
[0.25425,0.58950,0.99896],
[0.24946,0.59943,0.99835],
[0.24427,0.60937,0.99697],
[0.23874,0.61931,0.99485],
[0.23288,0.62923,0.99202],
[0.22676,0.63913,0.98851],
[0.22039,0.64901,0.98436],
[0.21382,0.65886,0.97959],
[0.20708,0.66866,0.97423],
[0.20021,0.67842,0.96833],
[0.19326,0.68812,0.96190],
[0.18625,0.69775,0.95498],
[0.17923,0.70732,0.94761],
[0.17223,0.71680,0.93981],
[0.16529,0.72620,0.93161],
[0.15844,0.73551,0.92305],
[0.15173,0.74472,0.91416],
[0.14519,0.75381,0.90496],
[0.13886,0.76279,0.89550],
[0.13278,0.77165,0.88580],
[0.12698,0.78037,0.87590],
[0.12151,0.78896,0.86581],
[0.11639,0.79740,0.85559],
[0.11167,0.80569,0.84525],
[0.10738,0.81381,0.83484],
[0.10357,0.82177,0.82437],
[0.10026,0.82955,0.81389],
[0.09750,0.83714,0.80342],
[0.09532,0.84455,0.79299],
[0.09377,0.85175,0.78264],
[0.09287,0.85875,0.77240],
[0.09267,0.86554,0.76230],
[0.09320,0.87211,0.75237],
[0.09451,0.87844,0.74265],
[0.09662,0.88454,0.73316],
[0.09958,0.89040,0.72393],
[0.10342,0.89600,0.71500],
[0.10815,0.90142,0.70599],
[0.11374,0.90673,0.69651],
[0.12014,0.91193,0.68660],
[0.12733,0.91701,0.67627],
[0.13526,0.92197,0.66556],
[0.14391,0.92680,0.65448],
[0.15323,0.93151,0.64308],
[0.16319,0.93609,0.63137],
[0.17377,0.94053,0.61938],
[0.18491,0.94484,0.60713],
[0.19659,0.94901,0.59466],
[0.20877,0.95304,0.58199],
[0.22142,0.95692,0.56914],
[0.23449,0.96065,0.55614],
[0.24797,0.96423,0.54303],
[0.26180,0.96765,0.52981],
[0.27597,0.97092,0.51653],
[0.29042,0.97403,0.50321],
[0.30513,0.97697,0.48987],
[0.32006,0.97974,0.47654],
[0.33517,0.98234,0.46325],
[0.35043,0.98477,0.45002],
[0.36581,0.98702,0.43688],
[0.38127,0.98909,0.42386],
[0.39678,0.99098,0.41098],
[0.41229,0.99268,0.39826],
[0.42778,0.99419,0.38575],
[0.44321,0.99551,0.37345],
[0.45854,0.99663,0.36140],
[0.47375,0.99755,0.34963],
[0.48879,0.99828,0.33816],
[0.50362,0.99879,0.32701],
[0.51822,0.99910,0.31622],
[0.53255,0.99919,0.30581],
[0.54658,0.99907,0.29581],
[0.56026,0.99873,0.28623],
[0.57357,0.99817,0.27712],
[0.58646,0.99739,0.26849],
[0.59891,0.99638,0.26038],
[0.61088,0.99514,0.25280],
[0.62233,0.99366,0.24579],
[0.63323,0.99195,0.23937],
[0.64362,0.98999,0.23356],
[0.65394,0.98775,0.22835],
[0.66428,0.98524,0.22370],
[0.67462,0.98246,0.21960],
[0.68494,0.97941,0.21602],
[0.69525,0.97610,0.21294],
[0.70553,0.97255,0.21032],
[0.71577,0.96875,0.20815],
[0.72596,0.96470,0.20640],
[0.73610,0.96043,0.20504],
[0.74617,0.95593,0.20406],
[0.75617,0.95121,0.20343],
[0.76608,0.94627,0.20311],
[0.77591,0.94113,0.20310],
[0.78563,0.93579,0.20336],
[0.79524,0.93025,0.20386],
[0.80473,0.92452,0.20459],
[0.81410,0.91861,0.20552],
[0.82333,0.91253,0.20663],
[0.83241,0.90627,0.20788],
[0.84133,0.89986,0.20926],
[0.85010,0.89328,0.21074],
[0.85868,0.88655,0.21230],
[0.86709,0.87968,0.21391],
[0.87530,0.87267,0.21555],
[0.88331,0.86553,0.21719],
[0.89112,0.85826,0.21880],
[0.89870,0.85087,0.22038],
[0.90605,0.84337,0.22188],
[0.91317,0.83576,0.22328],
[0.92004,0.82806,0.22456],
[0.92666,0.82025,0.22570],
[0.93301,0.81236,0.22667],
[0.93909,0.80439,0.22744],
[0.94489,0.79634,0.22800],
[0.95039,0.78823,0.22831],
[0.95560,0.78005,0.22836],
[0.96049,0.77181,0.22811],
[0.96507,0.76352,0.22754],
[0.96931,0.75519,0.22663],
[0.97323,0.74682,0.22536],
[0.97679,0.73842,0.22369],
[0.98000,0.73000,0.22161],
[0.98289,0.72140,0.21918],
[0.98549,0.71250,0.21650],
[0.98781,0.70330,0.21358],
[0.98986,0.69382,0.21043],
[0.99163,0.68408,0.20706],
[0.99314,0.67408,0.20348],
[0.99438,0.66386,0.19971],
[0.99535,0.65341,0.19577],
[0.99607,0.64277,0.19165],
[0.99654,0.63193,0.18738],
[0.99675,0.62093,0.18297],
[0.99672,0.60977,0.17842],
[0.99644,0.59846,0.17376],
[0.99593,0.58703,0.16899],
[0.99517,0.57549,0.16412],
[0.99419,0.56386,0.15918],
[0.99297,0.55214,0.15417],
[0.99153,0.54036,0.14910],
[0.98987,0.52854,0.14398],
[0.98799,0.51667,0.13883],
[0.98590,0.50479,0.13367],
[0.98360,0.49291,0.12849],
[0.98108,0.48104,0.12332],
[0.97837,0.46920,0.11817],
[0.97545,0.45740,0.11305],
[0.97234,0.44565,0.10797],
[0.96904,0.43399,0.10294],
[0.96555,0.42241,0.09798],
[0.96187,0.41093,0.09310],
[0.95801,0.39958,0.08831],
[0.95398,0.38836,0.08362],
[0.94977,0.37729,0.07905],
[0.94538,0.36638,0.07461],
[0.94084,0.35566,0.07031],
[0.93612,0.34513,0.06616],
[0.93125,0.33482,0.06218],
[0.92623,0.32473,0.05837],
[0.92105,0.31489,0.05475],
[0.91572,0.30530,0.05134],
[0.91024,0.29599,0.04814],
[0.90463,0.28696,0.04516],
[0.89888,0.27824,0.04243],
[0.89298,0.26981,0.03993],
[0.88691,0.26152,0.03753],
[0.88066,0.25334,0.03521],
[0.87422,0.24526,0.03297],
[0.86760,0.23730,0.03082],
[0.86079,0.22945,0.02875],
[0.85380,0.22170,0.02677],
[0.84662,0.21407,0.02487],
[0.83926,0.20654,0.02305],
[0.83172,0.19912,0.02131],
[0.82399,0.19182,0.01966],
[0.81608,0.18462,0.01809],
[0.80799,0.17753,0.01660],
[0.79971,0.17055,0.01520],
[0.79125,0.16368,0.01387],
[0.78260,0.15693,0.01264],
[0.77377,0.15028,0.01148],
[0.76476,0.14374,0.01041],
[0.75556,0.13731,0.00942],
[0.74617,0.13098,0.00851],
[0.73661,0.12477,0.00769],
[0.72686,0.11867,0.00695],
[0.71692,0.11268,0.00629],
[0.70680,0.10680,0.00571],
[0.69650,0.10102,0.00522],
[0.68602,0.09536,0.00481],
[0.67535,0.08980,0.00449],
[0.66449,0.08436,0.00424],
[0.65345,0.07902,0.00408],
[0.64223,0.07380,0.00401],
[0.63082,0.06868,0.00401],
[0.61923,0.06367,0.00410],
[0.60746,0.05878,0.00427],
[0.59550,0.05399,0.00453],
[0.58336,0.04931,0.00486],
[0.57103,0.04474,0.00529],
[0.55852,0.04028,0.00579],
[0.54583,0.03593,0.00638],
[0.53295,0.03169,0.00705],
[0.51989,0.02756,0.00780],
[0.50664,0.02354,0.00863],
[0.49321,0.01963,0.00955],
[0.47960,0.01583,0.01055]]
)
def RGBToPyCmap(rgbdata):
    """RGBToPyCmap

    Convert a table of RGB values into a dictionary with one list of anchor points per color channel.

    Parameters
    ----------
    rgbdata: numpy.ndarray
        2D array with one row per color step; columns 0, 1, and 2 hold the red, green, and blue values

    Returns
    ----------
    dict
        Dictionary with keys 'red', 'green', and 'blue'. Each value is a list with one `(position, value, value)` tuple per row of `rgbdata`, where the positions are evenly spaced between 0 and 1
    """

    # evenly spaced anchor positions, one per row of the table
    positions = np.linspace(0, 1, rgbdata.shape[0])

    # the channel value is repeated so both sides of each anchor carry the same value
    return {
        channel: [(pos, row[col], row[col]) for pos, row in zip(positions, rgbdata)]
        for col, channel in enumerate(('red', 'green', 'blue'))
    }
# Channel-wise anchor points for the forward table and for the table with its
# rows reversed; both stay available as module-level names.
mpl_data = RGBToPyCmap(turbo_colormap_data)
mpl_data_r = RGBToPyCmap(turbo_colormap_data[::-1,:])

# Register both colormaps with matplotlib. Registration is deliberately best
# effort: matplotlib releases that already provide 'turbo', or that no longer
# offer `plt.register_cmap` with these arguments, raise here and the import of
# this module must not fail because of that. Only regular exceptions are
# ignored, so KeyboardInterrupt/SystemExit still propagate.
for _cmap_name, _cmap_data in (("turbo", mpl_data), ("turbo_r", mpl_data_r)):
    try:
        plt.register_cmap(
            name=_cmap_name,
            data=_cmap_data,
            lut=turbo_colormap_data.shape[0]
        )
    except Exception:
        pass