Source code for nrv.eit.utils._eit_plot

import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
from pandas import DataFrame
import seaborn as sns

from ..results import eit_results_list

from ...backend import load_any
from ...fmod import CUFF_MP_electrode
from ...ui import load_nerve
from ...utils import sci_round, get_MRG_parameters


class Figure_elec:
    """
    Figure with one axis per electrode arranged around a central axis.

    The electrode axes occupy the border cells of a square grid (3x3, 4x4 or
    5x5 for 8, 12 or 16 electrodes). The remaining central cells are merged
    into a single axis, stored as the last element of :attr:`axs`, which is
    the axis used by :meth:`add_nerve_plot` and :meth:`color_elec`.

    Parameters
    ----------
    n_e : int
        Number of electrodes; must be 8, 12 or 16.
    fig : matplotlib figure, optional
        Figure to draw in. When ``None`` and no ``spec`` is given, a new
        figure is created with ``plt.figure(**fig_kwgs)``.
    spec : grid specification, optional
        Existing grid in which the axes are placed. When ``None`` a new grid
        of the required size is added to the figure.
    ij_offset : tuple of int, optional
        (row, column) offset of the layout inside ``spec``. Only used when
        ``spec`` is given.
    small_fig : bool, optional
        If True, titles of the bottom-row axes are placed under the axes.
    **fig_kwgs
        Forwarded to ``plt.figure`` when a figure has to be created.

    Raises
    ------
    ValueError
        If ``n_e`` is not supported, or if ``spec`` is given without any
        figure to attach the axes to.
    """

    # n_e -> (grid side, border cells (row, col) listed in electrode order)
    _LAYOUTS = {
        8: (
            3,
            ((0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)),
        ),
        12: (
            4,
            (
                (0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (3, 3),
                (3, 2), (3, 1), (3, 0), (2, 0), (1, 0), (0, 0),
            ),
        ),
        16: (
            5,
            (
                (0, 2), (0, 3), (0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (4, 3),
                (4, 2), (4, 1), (4, 0), (3, 0), (2, 0), (1, 0), (0, 0), (0, 1),
            ),
        ),
    }

    def __init__(
        self, n_e, fig=None, spec=None, ij_offset=(0, 0), small_fig=False, **fig_kwgs
    ):
        self.n_e = n_e
        self._fig = fig
        self.spec = spec
        self.ij_offset = ij_offset
        self.small_fig = small_fig
        self.__init_figure(**fig_kwgs)

    @property
    def fig(self):
        """Figure holding the axes."""
        return self._fig

    @property
    def axs(self):
        """List of the ``n_e`` electrode axes followed by the central axis."""
        return self._axs

    def __init_figure(self, **fig_kwgs):
        """Create the electrode axes and the central axis."""
        if self.n_e not in self._LAYOUTS:
            raise ValueError(
                f"Unsupported number of electrodes: {self.n_e} "
                f"(expected one of {sorted(self._LAYOUTS)})"
            )
        side, cells = self._LAYOUTS[self.n_e]

        if self.spec is None:
            if self._fig is None:
                self._fig = plt.figure(**fig_kwgs)
            self.spec = self._fig.add_gridspec(side, side)
            # The offset only applies inside a user-provided grid
            d_i, d_j = 0, 0
        else:
            if self._fig is None:
                # Reuse the figure attached to the grid when there is one
                self._fig = getattr(self.spec, "figure", None)
            if self._fig is None:
                raise ValueError("fig must be provided when spec is not bound to a figure")
            d_i, d_j = self.ij_offset

        self._axs = []
        for i_e, (row, col) in enumerate(cells):
            ax_ = self._fig.add_subplot(self.spec[row + d_i, col + d_j])
            # Bottom row is tested on the un-shifted row so that the offset
            # does not disable the small-figure title placement
            if row == side - 1 and self.small_fig:
                ax_.set_title(f"E{i_e}", y=-0.1)
            else:
                ax_.set_title(f"E{i_e}")
            ax_.set_axis_off()
            self._axs.append(ax_)

        # Central axis spans every cell that is not on the border
        center = (
            slice(1 + d_i, side - 1 + d_i),
            slice(1 + d_j, side - 1 + d_j),
        )
        self._axs.append(self._fig.add_subplot(self.spec[center]))

    def __setup_data(self, data, t=None, i_res=0, which="dv_eit"):
        """
        Extract the values and time vector to plot from ``data``.

        Returns
        -------
        tuple
            ``(values, time)``. For an array, these are ``data`` and ``t``
            unchanged; for an ``eit_results_list`` they come from
            ``data.get_res(i_res=i_res, which=which)`` and ``data.t()``.

        Raises
        ------
        ValueError
            If the electrode count of ``data`` differs from ``self.n_e``.
        TypeError
            If ``data`` is neither an array nor an ``eit_results_list``.
        """
        if isinstance(data, np.ndarray):
            if self.n_e != data.shape[-1]:
                raise ValueError(
                    f"Wrong number of electrodes (figure: {self.n_e}, data: {data.shape[-1]}) Data cannot be ploted"
                )
            return data, t
        if isinstance(data, eit_results_list):
            if self.n_e != data.n_e:
                raise ValueError(
                    f"Wrong number of electrodes (figure: {self.n_e}, res_list: {data.n_e}) Data cannot be ploted"
                )
            return data.get_res(i_res=i_res, which=which), data.t()
        raise TypeError(
            f"Unsupported data type {type(data).__name__}: expected numpy.ndarray or eit_results_list"
        )

    def plot_all_elec(self, data, t=None, i_res=0, which="dv_eit", **kwgs):
        """
        Plot ``data`` on every electrode axis.

        Parameters
        ----------
        data : numpy.ndarray, eit_results_list or list of these
            Values to plot; the last array dimension indexes the electrodes.
            A list is plotted element by element.
        t : numpy.ndarray or list, optional
            Time vector(s) used with array data. When both ``data`` and ``t``
            are lists they are paired by index.
        i_res, which
            Forwarded to ``eit_results_list.get_res``.
        **kwgs
            Forwarded to the axes ``plot`` method.

        Returns
        -------
        list
            The axes of the figure.

        Raises
        ------
        ValueError
            If the data is zero-dimensional or has the wrong electrode count.
        """
        if isinstance(data, list):
            t_is_list = isinstance(t, list)
            for i_d, d_ in enumerate(data):
                self.plot_all_elec(
                    data=d_,
                    t=t[i_d] if t_is_list else t,
                    i_res=i_res,
                    which=which,
                    **kwgs,
                )
            return self.axs

        _dv, _t = self.__setup_data(data=data, t=t, i_res=i_res, which=which)
        if _dv.ndim == 0:
            raise ValueError("Not enough dimensions. Data cannot be ploted")
        if _dv.ndim == 1:
            # A single trace is repeated on every electrode axis
            for i_e in range(self.n_e):
                self.axs[i_e].plot(_t, _dv, **kwgs)
        elif _dv.ndim == 2:
            for i_e in range(self.n_e):
                self.axs[i_e].plot(_t, _dv[:, i_e], **kwgs)
        elif _dv.ndim == 3:
            for i_p in range(_dv.shape[0]):
                # A 2D time array provides one time vector per trace
                t_i = _t[i_p] if _t.ndim == 2 else _t
                for i_e in range(self.n_e):
                    self.axs[i_e].plot(t_i, _dv[i_p, :, i_e], **kwgs)
        return self.axs

    def boxplot(self, data: DataFrame, expr="", **kwgs):
        """
        Draw a seaborn boxplot on every electrode axis.

        Parameters
        ----------
        data : DataFrame
            Must contain an ``i_e`` column; rows are split per electrode with
            ``data.query``.
        expr : str, optional
            Extra query expression and-ed with the electrode selection.
        **kwgs
            Forwarded to ``sns.boxplot`` (``legend`` is forced to False).

        Returns
        -------
        list
            The axes of the figure.

        Raises
        ------
        ValueError
            If ``data`` has no ``i_e`` column.
        """
        if "i_e" not in data:
            raise ValueError("no electrode colone in DataFrame")
        if len(expr) != 0:
            expr += " and "
        kwgs["legend"] = False
        for i_e in range(self.n_e):
            sns.boxplot(
                ax=self.axs[i_e], data=data.query(expr + f"i_e=={i_e}"), **kwgs
            )
        return self.axs

    # NOTE(review): not implemented, and declared without ``self``: when called
    # on an instance, the instance is bound to ``data``. Signature kept as-is.
    def snsplot(data: DataFrame, type="lineplot", expr="", **kwgs):
        pass

    def fill_between_all_elec(
        self, data_1, data_2, t=None, i_res=0, which="dv_eit", **kwgs
    ):
        """
        Fill the area between ``data_1`` and ``data_2`` on every electrode axis.

        Parameters
        ----------
        data_1, data_2 : numpy.ndarray or eit_results_list
            Lower and upper curves; the last dimension indexes the electrodes.
        t : numpy.ndarray, optional
            Time vector used with array data. If it has more than one
            dimension only its first row is used.
        i_res, which
            Forwarded to ``eit_results_list.get_res``.
        **kwgs
            Forwarded to the axes ``fill_between`` method.

        Returns
        -------
        list
            The axes of the figure.
        """
        _dv_1, _ = self.__setup_data(data=data_1, t=t, i_res=i_res, which=which)
        _dv_2, _t = self.__setup_data(data=data_2, t=t, i_res=i_res, which=which)
        # The time vector comes from __setup_data, so t=None is valid with an
        # eit_results_list
        t_i = _t[0] if _t.ndim > 1 else _t
        for i_e in range(self.n_e):
            self.axs[i_e].fill_between(t_i, _dv_1[..., i_e], _dv_2[..., i_e], **kwgs)
        return self.axs

    def add_nerve_plot(
        self,
        data,
        add_elec=True,
        drive_pair=(0, 2),
        e_label=True,
        n_lwidth=2,
        e_lwidth=3,
        **kwgs,
    ):
        """
        Draw the nerve, and optionally the electrodes, on the central axis.

        Parameters
        ----------
        data
            Anything accepted by ``load_any``; the loaded object must provide
            ``plot`` and a diameter ``D``.
        add_elec : bool, optional
            If True, a ``CUFF_MP_electrode`` is drawn around the nerve in black.
        drive_pair : tuple or None, optional
            Electrodes redrawn in red (first item) and blue (second item).
            ``None`` disables the highlighting.
        e_label : bool, optional
            Forwarded to the electrode ``plot`` for the black drawing.
        n_lwidth, e_lwidth : float, optional
            Line widths of the nerve and of the electrodes.
        **kwgs
            ``n_e`` overrides the number of contacts, which otherwise is the
            number of electrode axes. Other keys are ignored.

        Returns
        -------
        list
            The axes of the figure.
        """
        nerve_ax = self.axs[-1]
        nerve = load_any(data)
        nerve.plot(nerve_ax, linewidth=n_lwidth)
        nerve_ax.set_axis_off()
        if add_elec:
            n_e = kwgs.pop("n_e", len(self.axs) - 1)
            w_elec = 0.5 * np.pi * nerve.D / n_e
            elec = CUFF_MP_electrode(
                N_contact=n_e,
                x_center=100,
                contact_width=w_elec,
                contact_length=100,
                insulator=False,
            )
            elec.plot(
                nerve_ax,
                nerve_d=nerve.D,
                color="k",
                e_label=e_label,
                linewidth=e_lwidth,
            )
            if drive_pair is not None:
                for list_e, color in zip(drive_pair[:2], ("r", "b")):
                    elec.plot(
                        nerve_ax,
                        nerve_d=nerve.D,
                        list_e=list_e,
                        color=color,
                        e_label=False,
                        linewidth=e_lwidth,
                    )
        del nerve
        return self.axs

    def color_elec(self, data, n_e, list_e, **kwgs):
        """
        Redraw the electrodes ``list_e`` on the central axis.

        Parameters
        ----------
        data
            Anything accepted by ``load_any``; the loaded object must provide
            a diameter ``D``.
        n_e : int
            Number of contacts of the cuff electrode.
        list_e
            Electrode selection forwarded to the electrode ``plot``.
        **kwgs
            Forwarded to the electrode ``plot``.

        Returns
        -------
        list
            The axes of the figure.
        """
        nerve = load_any(data)
        w_elec = 0.5 * np.pi * nerve.D / n_e
        elec = CUFF_MP_electrode(
            N_contact=n_e,
            x_center=100,
            contact_width=w_elec,
            contact_length=100,
            insulator=False,
        )
        elec.plot(self.axs[-1], nerve_d=nerve.D, list_e=list_e, **kwgs)
        del nerve
        return self.axs

    def scale_axs(
        self,
        i_ax=-2,
        unit_x="ms",
        unit_y="V",
        e_gnd=(0,),
        zerox=False,
        zeroy=False,
        has_nerve=True,
    ):
        """
        Give all electrode axes the same y-limits and add scale bars.

        Parameters
        ----------
        i_ax : int, iterable of two int or None, optional
            Index of the axis receiving the scale bars, or ``(y, t)`` indexes
            of the axes receiving the amplitude and time bars. ``None`` skips
            the scale bars.
        unit_x, unit_y : str, optional
            Units appended to the time and amplitude labels.
        e_gnd : iterable of int, optional
            Electrodes ignored when computing the common y-limits.
        zerox, zeroy : bool, optional
            If True, draw the ``y = 0`` (resp. ``x = 0``) line on each axis.
        has_nerve : bool, optional
            If True, the last axis is the central one and is left untouched.

        Returns
        -------
        list
            The axes of the figure.
        """
        elec_axs = self.axs[:-1] if has_nerve else self.axs

        # Common y-limits always include 0
        min_y, max_y = 0, 0
        for i_e, ax_ in enumerate(elec_axs):
            if i_e not in e_gnd:
                lo_, hi_ = ax_.get_ylim()
                min_y = min(lo_, min_y)
                max_y = max(hi_, max_y)

        for ax_ in elec_axs:
            ax_.set_ylim(min_y, max_y)
            if zerox:
                _min_x, _max_x = ax_.get_xlim()
                ax_.plot([_min_x, _max_x], [0, 0], "k")
            if zeroy:
                # Span the common limits that were just applied
                ax_.plot([0, 0], [min_y, max_y], "k")

        if i_ax is None:
            return self.axs
        if np.iterable(i_ax):
            y_i_ax, t_i_ax = i_ax[0], i_ax[1]
        else:
            y_i_ax, t_i_ax = i_ax, i_ax
        ax_y = self.axs[y_i_ax]
        ax_t = self.axs[t_i_ax]

        # Read every limit before drawing so the bars cannot alter them
        t_xlim, t_ylim = ax_t.get_xlim(), ax_t.get_ylim()
        y_xlim, y_ylim = ax_y.get_xlim(), ax_y.get_ylim()

        # Adding t(x) scale: horizontal bar of about 20% of the x-range
        _min_t, _max_t = t_xlim
        _min_y, _max_y = t_ylim
        Dt = _max_t - _min_t
        Dy = _max_y - _min_y
        scale_t = sci_round(0.2 * Dt, 1)
        if scale_t >= 1:
            scale_t = int(scale_t)
        x_st = [_min_t + 0.1 * Dt, _min_t + 0.1 * Dt + scale_t]
        y_st = [_min_y + 0.1 * Dy, _min_y + 0.1 * Dy]
        ax_t.plot(x_st, y_st, color="k", linewidth=3)
        ax_t.text(
            x_st[0],
            y_st[0] + 0.01 * Dy,
            f"{scale_t}{unit_x}",
            ha="left",
            va="bottom",
            style="italic",
        )

        # Adding y scale: vertical bar of about 20% of the y-range
        _min_t, _max_t = y_xlim
        _min_y, _max_y = y_ylim
        Dt = _max_t - _min_t
        Dy = _max_y - _min_y
        scale_y = sci_round(0.2 * Dy, 1)
        if scale_y >= 1:
            scale_y = int(scale_y)
        x_sy = [_min_t + 0.1 * Dt, _min_t + 0.1 * Dt]
        y_sy = [_min_y + 0.2 * Dy, _min_y + 0.2 * Dy + scale_y]
        ax_y.plot(x_sy, y_sy, color="k", linewidth=3)
        ax_y.text(
            x_sy[0] + 0.05 * Dt,
            np.mean(y_sy),
            f"{scale_y}{unit_y}",
            ha="left",
            va="center",
            style="italic",
        )
        return self.axs
def gen_fig_elec(
    n_e, fig=None, spec=None, ij_offset=(0, 0), small_fig=False, **fig_kwgs
):
    """
    Generate a figure with one axis per electrode around a central axis.

    The electrode axes occupy the border cells of a square grid (3x3, 4x4 or
    5x5 for 8, 12 or 16 electrodes) and the remaining central cells are
    merged into one last axis.

    Parameters
    ----------
    n_e : int
        Number of electrodes; must be 8, 12 or 16.
    fig : matplotlib figure, optional
        Figure to draw in. When ``None`` and no ``spec`` is given, a new
        figure is created with ``plt.figure(**fig_kwgs)``.
    spec : grid specification, optional
        Existing grid in which the axes are placed. When ``None`` a new grid
        of the required size is added to the figure.
    ij_offset : tuple of int, optional
        (row, column) offset of the layout inside ``spec``. Only used when
        ``spec`` is given.
    small_fig : bool, optional
        If True, titles of the bottom-row axes are placed under the axes.
    **fig_kwgs
        Forwarded to ``plt.figure`` when a figure has to be created.

    Returns
    -------
    fig : matplotlib figure
        Figure holding the axes.
    axs : list
        The ``n_e`` electrode axes followed by the central axis.

    Raises
    ------
    ValueError
        If ``n_e`` is not supported, or if ``spec`` is given without any
        figure to attach the axes to.
    """
    # n_e -> (grid side, border cells (row, col) listed in electrode order)
    layouts = {
        8: (
            3,
            ((0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)),
        ),
        12: (
            4,
            (
                (0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (3, 3),
                (3, 2), (3, 1), (3, 0), (2, 0), (1, 0), (0, 0),
            ),
        ),
        16: (
            5,
            (
                (0, 2), (0, 3), (0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (4, 3),
                (4, 2), (4, 1), (4, 0), (3, 0), (2, 0), (1, 0), (0, 0), (0, 1),
            ),
        ),
    }
    if n_e not in layouts:
        raise ValueError(
            f"Unsupported number of electrodes: {n_e} "
            f"(expected one of {sorted(layouts)})"
        )
    side, cells = layouts[n_e]

    if spec is None:
        if fig is None:
            fig = plt.figure(**fig_kwgs)
        spec = fig.add_gridspec(side, side)
        # The offset only applies inside a user-provided grid
        d_i, d_j = 0, 0
    else:
        if fig is None:
            # Reuse the figure attached to the grid when there is one
            fig = getattr(spec, "figure", None)
        if fig is None:
            raise ValueError("fig must be provided when spec is not bound to a figure")
        d_i, d_j = ij_offset

    axs = []
    for i_e, (row, col) in enumerate(cells):
        ax_ = fig.add_subplot(spec[row + d_i, col + d_j])
        # Bottom row is tested on the un-shifted row so that the offset does
        # not disable the small-figure title placement
        if row == side - 1 and small_fig:
            ax_.set_title(f"E{i_e}", y=-0.1)
        else:
            ax_.set_title(f"E{i_e}")
        ax_.set_axis_off()
        axs.append(ax_)

    # Central axis spans every cell that is not on the border
    center = (slice(1 + d_i, side - 1 + d_i), slice(1 + d_j, side - 1 + d_j))
    axs.append(fig.add_subplot(spec[center]))
    return fig, axs
def plot_all_elec(
    res_list,
    t=None,
    axs=None,
    i_res=0,
    which="dv_eit",
    same_scale=True,
    labels=None,
    **kwgs,
):
    """
    Plot a result on one axis per electrode.

    Parameters
    ----------
    res_list : numpy.ndarray or eit_results_list
        Values to plot; the last array dimension indexes the electrodes. For
        an ``eit_results_list`` the values come from
        ``res_list.get_res(i_res=i_res, which=which)`` and the time vector
        from ``res_list.t()``.
    t : array-like, optional
        Time vector; required when ``res_list`` is an array.
    axs : list, optional
        Axes to draw on. When ``None`` they are created with
        ``gen_fig_elec``.
    i_res, which
        Forwarded to ``eit_results_list.get_res``.
    same_scale, labels
        Currently unused.
    **kwgs
        Forwarded to the axes ``plot`` method.

    Returns
    -------
    tuple or list
        ``(fig, axs)`` when the axes were created here, else ``axs``.

    Raises
    ------
    TypeError
        If ``res_list`` is neither an array nor an ``eit_results_list``.
    ValueError
        If no time vector is available.
    """
    if isinstance(res_list, np.ndarray):
        n_e = res_list.shape[-1]
        _dv = res_list
    elif isinstance(res_list, eit_results_list):
        n_e = res_list.n_e
        _dv = res_list.get_res(i_res=i_res, which=which)
        t = res_list.t()
    else:
        raise TypeError(
            f"Unsupported data type {type(res_list).__name__}: expected numpy.ndarray or eit_results_list"
        )
    if t is None:
        raise ValueError("t must be provided when res_list is an array")
    t = np.asarray(t)

    is_generated = axs is None
    if is_generated:
        fig, axs = gen_fig_elec(n_e)

    for i_e in range(n_e):
        ax_ = axs[i_e]
        dv_i = _dv[..., i_e]
        # Several traces sharing one time vector: put time on the first axis
        if t.ndim < np.ndim(dv_i):
            dv_i = dv_i.T
        ax_.plot(t, dv_i, **kwgs)
        ax_.set_xlim(0, t[-1])

    if is_generated:
        return fig, axs
    return axs
def fill_between_all_elec(axs, res_list_1, res_list_2, t=None, **kwgs):
    """
    Fill the area between two results on every electrode axis.

    Parameters
    ----------
    axs : list
        Axes to draw on, one per electrode (extra axes are left untouched).
    res_list_1, res_list_2 : numpy.ndarray
        Lower and upper curves; the last dimension indexes the electrodes.
    t : numpy.ndarray
        Time vector passed as x-values to ``fill_between``.
    **kwgs
        Forwarded to the axes ``fill_between`` method.

    Returns
    -------
    list
        ``axs``.

    Raises
    ------
    TypeError
        If one of the results is not an array.
    ValueError
        If ``t`` is not provided.
    """
    for res_ in (res_list_1, res_list_2):
        if not isinstance(res_, np.ndarray):
            raise TypeError(
                f"Unsupported data type {type(res_).__name__}: expected numpy.ndarray"
            )
    if t is None:
        raise ValueError("t must be provided")

    n_e = res_list_1.shape[-1]
    for i_e in range(n_e):
        axs[i_e].fill_between(t, res_list_1[..., i_e], res_list_2[..., i_e], **kwgs)
    return axs
def add_nerve_plot(
    axs,
    data,
    add_elec=True,
    drive_pair=(0, 2),
    e_label=True,
    n_lwidth=2,
    e_lwidth=3,
    **kwgs,
):
    """
    Draw the nerve, and optionally the electrodes, on the last axis of ``axs``.

    Parameters
    ----------
    axs : list
        Axes of the figure; the nerve is drawn on ``axs[-1]``.
    data
        Anything accepted by ``load_any``; the loaded object must provide
        ``plot`` and a diameter ``D``.
    add_elec : bool, optional
        If True, a ``CUFF_MP_electrode`` is drawn around the nerve in black.
    drive_pair : tuple or None, optional
        Electrodes redrawn in red (first item) and blue (second item).
        ``None`` disables the highlighting.
    e_label : bool, optional
        Forwarded to the electrode ``plot`` for the black drawing.
    n_lwidth, e_lwidth : float, optional
        Line widths of the nerve and of the electrodes.
    **kwgs
        ``n_e`` overrides the number of contacts, which otherwise is
        ``len(axs) - 1``. Other keys are ignored.

    Returns
    -------
    list
        ``axs``.
    """
    nerve_ax = axs[-1]
    nerve = load_any(data)
    nerve.plot(nerve_ax, linewidth=n_lwidth)
    nerve_ax.set_axis_off()
    if add_elec:
        n_e = kwgs.pop("n_e", len(axs) - 1)
        w_elec = 0.5 * np.pi * nerve.D / n_e
        elec = CUFF_MP_electrode(
            N_contact=n_e,
            x_center=100,
            contact_width=w_elec,
            contact_length=100,
            insulator=False,
        )
        elec.plot(
            nerve_ax, nerve_d=nerve.D, color="k", e_label=e_label, linewidth=e_lwidth
        )
        # Same guard as Figure_elec.add_nerve_plot: None means no highlight
        if drive_pair is not None:
            for list_e, color in zip(drive_pair[:2], ("r", "b")):
                elec.plot(
                    nerve_ax,
                    nerve_d=nerve.D,
                    list_e=list_e,
                    color=color,
                    e_label=False,
                    linewidth=e_lwidth,
                )
    del nerve
    return axs
def color_elec(axs, data, n_e, list_e, **kwgs):
    """
    Redraw the electrodes ``list_e`` on the last axis of ``axs``.

    Parameters
    ----------
    axs : list
        Axes of the figure; the electrodes are drawn on ``axs[-1]``.
    data
        Anything accepted by ``load_any``; the loaded object must provide a
        diameter ``D``.
    n_e : int
        Number of contacts of the cuff electrode.
    list_e
        Electrode selection forwarded to the electrode ``plot``.
    **kwgs
        Forwarded to the electrode ``plot``.

    Returns
    -------
    list
        ``axs``.
    """
    loaded = load_any(data)
    cuff = CUFF_MP_electrode(
        N_contact=n_e,
        x_center=100,
        contact_width=0.5 * np.pi * loaded.D / n_e,
        contact_length=100,
        insulator=False,
    )
    target_ax = axs[-1]
    cuff.plot(target_ax, nerve_d=loaded.D, list_e=list_e, **kwgs)
    # Drop the loaded object before returning
    del loaded
    return axs
def scale_axs(
    axs,
    i_ax=-2,
    unit_x="ms",
    unit_y="V",
    e_gnd=(0,),
    zerox=False,
    zeroy=False,
    has_nerve=True,
):
    """
    Give all electrode axes the same y-limits and add scale bars.

    Parameters
    ----------
    axs : list
        Axes of the figure.
    i_ax : int, iterable of two int or None, optional
        Index of the axis receiving the scale bars, or ``(y, t)`` indexes of
        the axes receiving the amplitude and time bars. ``None`` skips the
        scale bars.
    unit_x, unit_y : str, optional
        Units appended to the time and amplitude labels.
    e_gnd : iterable of int, optional
        Electrodes ignored when computing the common y-limits.
    zerox, zeroy : bool, optional
        If True, draw the ``y = 0`` (resp. ``x = 0``) line on each axis.
    has_nerve : bool, optional
        If True, the last axis of ``axs`` is left untouched.

    Returns
    -------
    list
        ``axs``.
    """
    elec_axs = axs[:-1] if has_nerve else axs

    # Common y-limits always include 0
    min_y, max_y = 0, 0
    for i_e, ax_ in enumerate(elec_axs):
        if i_e not in e_gnd:
            lo_, hi_ = ax_.get_ylim()
            min_y = min(lo_, min_y)
            max_y = max(hi_, max_y)

    for ax_ in elec_axs:
        ax_.set_ylim(min_y, max_y)
        if zerox:
            _min_x, _max_x = ax_.get_xlim()
            ax_.plot([_min_x, _max_x], [0, 0], "k")
        if zeroy:
            # Span the common limits that were just applied
            ax_.plot([0, 0], [min_y, max_y], "k")

    if i_ax is None:
        return axs
    if np.iterable(i_ax):
        y_i_ax, t_i_ax = i_ax[0], i_ax[1]
    else:
        y_i_ax, t_i_ax = i_ax, i_ax
    ax_y = axs[y_i_ax]
    ax_t = axs[t_i_ax]

    # Read every limit before drawing so the bars cannot alter them
    t_xlim, t_ylim = ax_t.get_xlim(), ax_t.get_ylim()
    y_xlim, y_ylim = ax_y.get_xlim(), ax_y.get_ylim()

    # Adding t(x) scale: horizontal bar of about 20% of the x-range
    _min_t, _max_t = t_xlim
    _min_y, _max_y = t_ylim
    Dt = _max_t - _min_t
    Dy = _max_y - _min_y
    scale_t = sci_round(0.2 * Dt, 1)
    if scale_t >= 1:
        scale_t = int(scale_t)
    x_st = [_min_t + 0.1 * Dt, _min_t + 0.1 * Dt + scale_t]
    y_st = [_min_y + 0.1 * Dy, _min_y + 0.1 * Dy]
    ax_t.plot(x_st, y_st, color="k", linewidth=3)
    ax_t.text(
        x_st[0],
        y_st[0] + 0.01 * Dy,
        f"{scale_t}{unit_x}",
        ha="left",
        va="bottom",
        style="italic",
    )

    # Adding y scale: vertical bar of about 20% of the y-range
    _min_t, _max_t = y_xlim
    _min_y, _max_y = y_ylim
    Dt = _max_t - _min_t
    Dy = _max_y - _min_y
    scale_y = sci_round(0.2 * Dy, 1)
    if scale_y >= 1:
        scale_y = int(scale_y)
    x_sy = [_min_t + 0.1 * Dt, _min_t + 0.1 * Dt]
    y_sy = [_min_y + 0.2 * Dy, _min_y + 0.2 * Dy + scale_y]
    ax_y.plot(x_sy, y_sy, color="k", linewidth=3)
    ax_y.text(
        x_sy[0] + 0.05 * Dt,
        np.mean(y_sy),
        f"{scale_y}{unit_y}",
        ha="left",
        va="center",
        style="italic",
    )
    return axs
def plot_nerve_nor(fname, l_elec, x_rec, fasc_ID="1"):
    """
    Plot three views of one fascicle of a saved nerve, one per subplot.

    The fascicle is reloaded from ``fname`` before each view, so the changes
    made for one view do not affect the others.

    Parameters
    ----------
    fname
        Nerve description forwarded to ``load_nerve``.
    l_elec : float
        Length of the window centred on ``x_rec``.
    x_rec : float
        Centre of the window.
    fasc_ID : str or int, optional
        Key of the fascicle in ``nerve.fascicles`` (converted to ``str``).

    Returns
    -------
    fig, axs
        Figure and its three axes, as returned by ``plt.subplots(3)``.
    """
    key = str(fasc_ID)
    half_len = l_elec / 2
    x_min, x_max = x_rec - half_len, x_rec + half_len

    # View 0: fascicle as saved (x-limits are restricted to the window below)
    fascicle = load_nerve(fname).fascicles[key]
    fig, axs = plt.subplots(3)
    fascicle.plot_x(axs[0])
    del fascicle

    # View 2: relative node positions zeroed, plus one segment [0, dx] per
    # axon and the window bounds in red
    # NOTE(review): item 5 of get_MRG_parameters is presumably the
    # node-to-node distance - confirm
    fascicle = load_nerve(fname).fascicles[key]
    fascicle.NoR_relative_position *= 0
    fascicle.plot_x(axs[2])
    deltaxs = get_MRG_parameters(fascicle.axons_diameter)[5]
    for row, seg_len in enumerate(deltaxs):
        axs[2].plot([0, seg_len], [row, row])
    del fascicle
    for bound in (x_min, x_max):
        axs[2].axvline(bound, color="red")

    # View 1: fascicle cut to the window length, relative positions
    # re-expressed with respect to x_min
    windowed = load_nerve(fname).fascicles[key]
    axs[0].set_xlim(x_min, x_max)
    deltaxs = get_MRG_parameters(windowed.axons_diameter)[5]
    windowed.define_length(x_max - x_min)
    abs_pos = windowed.NoR_relative_position * deltaxs
    shifted = np.mod((abs_pos - x_min), deltaxs)
    windowed.NoR_relative_position = shifted / deltaxs
    windowed.plot_x(
        axs[1],
    )
    return fig, axs