import numpy as np
import os
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
from copy import deepcopy
from ...fmod.FEM.fenics_utils import get_sig_ap
from ...nmod.results import nerve_results, axon_results
from ...ui import sample_keys
from ...utils import sci_round, get_MRG_parameters, convert
# File handling
def touch(path):
    """Create ``path`` if it does not exist and set its access/modification times to now.

    Existing content is preserved: the file is opened in append mode, which never
    truncates.
    """
    with open(path, "a") as _handle:
        # omitting ``times`` is the same as passing None: use the current time
        os.utime(path)
# NumPy utilities
def gen_from_idx(idx: np.ndarray, n: int, add_0: bool = False) -> np.ndarray:
    """
    Return, for every position in ``range(n)``, the number of the segment it belongs to.

    Segments are delimited by the sorted split points ``idx``; the segment number
    is obtained with ``np.searchsorted(idx, position)`` (default ``side="left"``).

    Parameters
    ----------
    idx : np.ndarray
        Sorted split points.
    n : int
        Number of positions.
    add_0 : bool, optional
        If True, the first value is duplicated at the front of the output.

    Returns
    -------
    np.ndarray
        Segment number of each position (length ``n``, or ``n + 1`` with ``add_0``).

    Note
    ----
    NOTE(review): with ``side="left"`` a position equal to a split point is counted in
    the preceding segment, while ``gen_idx_arange`` uses ``side="right"``. Positions
    are shifted by the number of zeros in ``idx``, which changes that convention only
    when ``idx`` contains 0. Confirm that this is intended.

    Example
    -------
    >>> gen_from_idx(np.array([2, 4]), 6).tolist()
    [0, 0, 0, 1, 1, 2]
    """
    # shift by the number of zero split points (adding 0 leaves positions unchanged)
    positions = np.arange(n) + np.sum(idx == 0)
    segment = np.searchsorted(idx, positions)
    return np.concatenate((segment[:1], segment)) if add_0 else segment
[docs]
def gen_idx_arange(idx: np.ndarray, n: int, add_0: bool = False) -> np.ndarray:
    """
    Return, for every position in ``range(n)``, its offset inside its own segment.

    With ``bounds = [0, *idx, n]``, segment ``k`` covers ``[bounds[k], bounds[k + 1])``
    and each position ``p`` of that segment is mapped to ``p - bounds[k]``.

    Parameters
    ----------
    idx : np.ndarray
        Sorted split points.
    n : int
        Number of positions.
    add_0 : bool, optional
        If True, a ``0`` is prepended to the output.

    Returns
    -------
    np.ndarray
        Offset of each position within its segment.

    Note
    ----
    NOTE(review): in the ``add_0`` branch the first ``bounds[0]`` offsets are
    incremented, but ``bounds[0]`` is always 0 after padding, so this is a no-op.
    It was possibly meant to shift the first segment (``bounds[1]``). Confirm the
    intended output before relying on ``add_0``.

    Example
    -------
    >>> gen_idx_arange(np.array([2, 4]), 6).tolist()
    [0, 1, 0, 1, 0, 1]
    """
    bounds = np.concatenate(([0], idx, [n]))
    positions = np.arange(n)
    # index of the segment containing each position
    segment = np.searchsorted(bounds[1:], positions, side="right")
    counter = positions - bounds[segment]
    if not add_0:
        return counter
    counter[: bounds[0]] += 1  # no-op kept on purpose, see NOTE(review) above
    return np.concatenate(([0], counter))
def adjust_axes(arr1: np.ndarray, arr2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Reshape ``arr2`` so that it broadcasts against ``arr1``.

    ``arr2`` is squeezed first. Each of its remaining axes is then matched, in order,
    to an axis of ``arr1`` of the same length. Every axis of ``arr1`` left unmatched
    becomes a singleton dimension of the reshaped ``arr2``.

    Warning
    -------
    Only ``arr2`` is reshaped; ``arr1`` is returned unchanged. The squeezed shape of
    ``arr2`` must appear, in order, as a subsequence of ``arr1.shape``. When several
    axes of ``arr1`` share the same length, the leftmost admissible one is used.

    Parameters
    ----------
    arr1 : np.ndarray
        Reference array whose number of dimensions is matched.
    arr2 : np.ndarray
        Array to reshape.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``arr1`` unchanged, and ``arr2`` reshaped to ``arr1.ndim`` dimensions.

    Raises
    ------
    ValueError
        If the squeezed ``arr2`` has more dimensions than ``arr1``, or if its axes
        cannot be matched in order to axes of ``arr1``.

    Example
    -------
    >>> adjust_axes(np.zeros((4, 3, 4)), np.zeros((3, 4)))[1].shape
    (1, 3, 4)
    """
    arr2 = arr2.squeeze()
    if arr2.ndim > arr1.ndim:
        raise ValueError(
            "arr2 cannot have more dimensions than arr1 under this assumption"
        )
    new_shape = [1] * arr1.ndim
    i_next = 0
    for length in arr2.shape:
        # Search only to the right of the previous match so axis order is preserved.
        # A plain first-occurrence lookup could pick an earlier axis and make two
        # axes of arr2 collide on the same position.
        candidates = [
            i_ax for i_ax in range(i_next, arr1.ndim) if arr1.shape[i_ax] == length
        ]
        if not candidates:
            raise ValueError(
                f"Cannot match arr2 shape {arr2.shape} to arr1 shape {arr1.shape}"
            )
        new_shape[candidates[0]] = length
        i_next = candidates[0] + 1
    return arr1, arr2.reshape(tuple(new_shape))
## Misc
def thr_window(X, alpha=0.4):
    """Return ``X`` with every entry not strictly above ``alpha * X.max()`` set to 0."""
    return X * (X > X.max() * alpha)


def in_circle(x, y, xc, yc, rc):
    """Return whether (x, y) lies strictly inside the circle of centre (xc, yc) and radius rc.

    Works element-wise (boolean array) when array coordinates are given.
    """
    return (x - xc) ** 2 + (y - yc) ** 2 < rc**2


def in_bbox(y, z, bbox):
    """Return whether the scalar point (y, z) lies strictly inside ``bbox``.

    The bounds used are ``bbox[1] < y < bbox[4]`` and ``bbox[2] < z < bbox[3]``.
    """
    # NOTE(review): y uses bbox[1]/bbox[4] but z uses bbox[2]/bbox[3]. With a
    # (xmin, ymin, zmin, xmax, ymax, zmax) layout the z upper bound would be
    # bbox[5]. Confirm the expected bbox layout.
    return y > bbox[1] and y < bbox[4] and z > bbox[2] and z < bbox[3]
def iterable_gen(obj, include_none=True, include_unitary=True):
    """
    Tell whether ``obj`` should be treated as a collection of values.

    Parameters
    ----------
    obj : any
        Object to test.
    include_none : bool, optional
        Value returned when ``obj`` is None, by default True.
    include_unitary : bool, optional
        Whether an iterable of length <= 1 still counts, by default True.

    Returns
    -------
    bool
        For an iterable: ``len(obj) > 1 or include_unitary``. For None:
        ``include_none``. Otherwise False.
    """
    if not np.iterable(obj):
        return obj is None and include_none
    return len(obj) > 1 or include_unitary
def split_job_from_arrays(len_arrays, n_split, stype="default"):
    """
    Split the indexes of an array into ``n_split`` disjoint groups for independent
    parallel computing.

    Parameters
    ----------
    len_arrays : int
        Length of the array containing the full job to perform in parallel.
    n_split : int
        Number of groups, typically the number of parallel instances.
    stype : str, optional
        Splitting method:
        - ``"comb"``: round-robin, group ``i`` gets indexes ``i, i + n_split, ...``;
        - any other value (default): contiguous chunks, as ``np.array_split``.

    Returns
    -------
    list[list]
        One list of indexes per group. A group may be empty when
        ``n_split > len_arrays``.
    """
    indexes = np.arange(len_arrays)
    if stype != "comb":
        return [list(chunk) for chunk in np.array_split(indexes, n_split, axis=0)]
    return [
        list(np.where(indexes % n_split == i_group)[0]) for i_group in range(n_split)
    ]
def rotate_axes(arr: np.ndarray, axis: int, target=0) -> np.ndarray:
    """
    Move ``axis`` of ``arr`` to position ``target``; the other axes keep their order.

    This is exactly :func:`numpy.moveaxis`. Negative ``axis`` / ``target`` count
    from the end.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    axis : int
        Axis to move.
    target : int, optional
        Destination position of the axis, by default 0.

    Returns
    -------
    np.ndarray
        View of ``arr`` with the axis moved.

    Raises
    ------
    AxisError
        (subclass of ValueError and IndexError) if ``axis`` or ``target`` is out of
        bounds. Previously an out-of-range negative axis was silently normalised to
        another negative index and swapped the wrong axes.

    Example
    -------
    >>> a = np.ones((3,6,4,7))
    >>> a.shape
    (3, 6, 4, 7)
    >>> rotate_axes(a, axis=2).shape
    (4, 3, 6, 7)
    >>> rotate_axes(a, axis=2, target=-1).shape
    (3, 6, 7, 4)
    >>> rotate_axes(a, axis=1, target=-1).shape
    (3, 4, 7, 6)
    """
    return np.moveaxis(arr, axis, target)
[docs]
def plot_array(
    ax: plt.Axes, x: np.ndarray, arr: np.ndarray, axis: int | None = None, **kwgs
):
    """
    Plot every 1D slice of ``arr`` taken along the axis that matches ``x``.

    Parameters
    ----------
    ax : plt.Axes
        Axes on which ``ax.plot`` is called.
    x : np.ndarray
        Abscissa shared by all the curves.
    arr : np.ndarray
        Values to plot. A 1D array is drawn directly. For higher dimensions, one
        curve is drawn per 1D slice along ``axis``.
    axis : int | None, optional
        Axis of ``arr`` along which ``x`` runs. If None, the last axis whose length
        equals ``len(x)`` is used.
    **kwgs
        Forwarded to every ``ax.plot`` call.

    Raises
    ------
    ValueError
        If ``len(x)`` does not match the chosen axis, or no axis matches it.
    """
    n_x = len(x)
    if arr.ndim == 1:
        if n_x != len(arr):
            raise ValueError(
                f"Dimension mismatch (x:{n_x}, arr:{len(arr)}), array cannot be ploted"
            )
        ax.plot(x, arr, **kwgs)
        return None

    if axis is None:
        matching_axes = [i_dim for i_dim, size in enumerate(arr.shape) if size == n_x]
        if not matching_axes:
            raise ValueError(
                f"No matching dim, (x:{n_x}, arr:{arr.shape}), array cannot be ploted"
            )
        # the last matching axis wins
        axis = matching_axes[-1]
    elif n_x != arr.shape[axis]:
        raise ValueError(
            f"Dimension mismatch (x:{n_x}, arr:{arr.shape[axis]}), array cannot be ploted"
        )

    # Put the x-axis last, then recurse over the leading dimension.
    for sub_arr in rotate_axes(arr, axis=axis, target=-1):
        plot_array(ax, x=x, arr=sub_arr, axis=-1, **kwgs)
    return None
## Nerve conductivity methods
def compute_myelin_ppt(d, model="MRG", f=0):
    """
    Extract the apparent conductivity of the myelin sheath for a given axon diameter.

    The sheath is treated as ``nl`` lamellae in series, with ``nl`` taken from
    ``get_MRG_parameters(d)``. Each lamella has a conductance of 0.001 S/cm**2 and
    a capacitance of 0.1 uF/cm**2. The resulting conductance per unit area is
    multiplied by the radius ``d / 2`` to give a conductivity.

    Parameters
    ----------
    d : float
        Axon diameter in micrometers (um), also passed to ``get_MRG_parameters``.
    model : str, optional
        Model used for parameter extraction. Default is "MRG".
        NOTE: not used by the current implementation (MRG parameters are always used).
    f : float, optional
        Frequency in kilohertz (kHz). If greater than 0, the capacitive effect is
        added as an imaginary term. Default is 0.

    Returns
    -------
    sig_mye : float or complex
        Apparent myelin conductivity in S/m (radius [m] times conductance [S/m**2]).
        Complex when ``f > 0``.
    """
    (_g, _axonD, _nodeD, _paraD1, _paraD2, _deltax, _paralength2, n_lamellae) = (
        get_MRG_parameters(d)
    )
    cm_lamella = 0.1  # uF/cm2/lamella membrane
    gm_lamella = 0.001  # S/cm2/lamella membrane
    radius_si = convert(d / 2, unitin="um", unitout="m")
    gm_si = convert(gm_lamella, unitin="S/cm**2", unitout="S/m**2")
    sig_mye = radius_si * gm_si / n_lamellae
    if not f > 0:
        return sig_mye
    cm_si = convert(cm_lamella, unitin="uF/cm**2", unitout="F/m**2")
    f_hz = convert(f, unitin="kHz", unitout="Hz")
    return sig_mye + 2j * np.pi * f_hz * radius_si * cm_si / n_lamellae
def compute_y_wth_uv(y: np.ndarray, R_x: np.ndarray) -> np.ndarray:
    """
    Compute the equivalent admittance seen at each node of a resistive ladder.

    Node ``k`` is connected to the outside through the admittance ``y[k]``.
    Consecutive nodes ``k`` and ``k + 1`` are connected through the axial
    resistance ``R_x[k]``. The equivalent admittance at node ``n`` is

        y[n] + 1 / (R_x[n-1] + 1 / Y_left[n-1]) + 1 / (R_x[n] + 1 / Y_right[n+1])

    where ``Y_left[k]`` (resp. ``Y_right[k]``) is the admittance of the sub-ladder
    made of nodes ``0..k`` (resp. ``k..end``) seen from node ``k``. Both sub-ladders
    are built with one sweep each, so the cost is O(n).

    Previous implementation issues fixed here:
    - the right-hand recursion counted the last node twice;
    - the left-hand recursion skipped node ``n`` for ``n >= 1``;
    - the left side used ``R_x[k]`` instead of ``R_x[k - 1]``;
    - complex admittances were cast into a real buffer.
    Sanity checks: a single node returns ``y``, and with ``R_x = 0`` every node
    returns ``sum(y)`` (same as ``compute_y_app``).

    Parameters
    ----------
    y : np.ndarray
        Admittance of each node (real or complex).
    R_x : np.ndarray or float
        Axial resistance between node ``k`` and ``k + 1``. A scalar is used for every
        segment. Only the first ``len(y) - 1`` values are used, so a trailing extra
        value (as built by ``compute_sigma_2D``) is ignored.

    Returns
    -------
    np.ndarray
        Equivalent admittance at each node, same length as ``y``.

    Raises
    ------
    ValueError
        If fewer than ``len(y) - 1`` resistances are provided.
    """
    y = np.asarray(y)
    n_x = len(y)
    n_seg = max(n_x - 1, 0)
    R_x = np.asarray(R_x)
    if R_x.ndim == 0:
        R_x = R_x * np.ones(n_seg)
    R_seg = R_x[:n_seg]
    if len(R_seg) != n_seg:
        raise ValueError(
            f"R_x must provide at least {n_seg} resistances, got {len(R_seg)}"
        )
    dtype = np.result_type(y.dtype, R_seg.dtype, np.float64)
    y_eq = y.astype(dtype)
    if n_x < 2:
        return y_eq

    # y_left[k]: admittance of nodes 0..k seen from node k
    y_left = np.empty(n_x, dtype=dtype)
    y_left[0] = y[0]
    for k in range(1, n_x):
        y_left[k] = y[k] + 1 / (R_seg[k - 1] + 1 / y_left[k - 1])

    # y_right[k]: admittance of nodes k..n_x-1 seen from node k
    y_right = np.empty(n_x, dtype=dtype)
    y_right[-1] = y[-1]
    for k in range(n_x - 2, -1, -1):
        y_right[k] = y[k] + 1 / (R_seg[k] + 1 / y_right[k + 1])

    # add the left ladder (through R_seg[n-1]) and the right ladder (through R_seg[n])
    y_eq[1:] += 1 / (R_seg + 1 / y_left[:-1])
    y_eq[:-1] += 1 / (R_seg + 1 / y_right[1:])
    return y_eq
def compute_y_app(
    y: np.ndarray, x_rec: np.ndarray, r_x: np.ndarray
) -> np.ndarray:
    """
    Approximate the equivalent admittance seen at each position.

    For each element of ``y``, computes a weighted sum using the formula:
        y_app[n] = sum(y / (1 + abs(x_rec[n] - x_rec) * r_x * y))

    Warning
    -------
    For now the 2D approximation isn't well documented. Further explanation will be
    added to the doc in the future.

    Parameters
    ----------
    y : np.ndarray
        Admittance values (real or complex).
    x_rec : np.ndarray
        Array of reference x positions.
    r_x : np.ndarray or float
        Scaling factor(s) for each x position.

    Returns
    -------
    y_app : np.ndarray
        Approximated values, same length as ``y``. Complex when ``y`` or ``r_x`` is
        complex; the previous real-valued buffer silently discarded the imaginary part.
    """
    y = np.asarray(y)
    x_rec = np.asarray(x_rec)
    n_x = len(y)
    dtype = np.result_type(y.dtype, np.asarray(r_x).dtype, np.float64)
    y_app = np.zeros(n_x, dtype=dtype)
    for n in range(n_x):
        kR_x = np.abs(x_rec[n] - x_rec) * r_x
        y_app[n] = np.sum(y / (1 + kR_x * y))
    return y_app
def compute_mye_sigma_2D(
    sig_m_t: np.ndarray,
    x_rec: np.ndarray,
    sig_mye: float,
    sig_in: float,
    sig_out: float,
    d_ax: float,
    d_node: float,
    alpha_th: float,
    l_elec: float,
) -> float:
    """
    Compute the apparent 2D myelin conductivity of a fiber segment at a given time.

    Without nodes, the myelin conductivity ``sig_mye`` is returned as is. Otherwise:
    - the node conductivity is built with two ``get_sig_ap`` calls (first
      ``sig_in`` with ``sig_m_t``, then the result with ``sig_out`` using
      ``(d_ax - d_node) / d_ax``);
    - it is averaged over the nodes;
    - it is blended with ``sig_mye`` using the weight ``0.1 * n_nodes / l_elec``.

    Warning
    -------
    For now the 2D approximation isn't well documented. Further explanation will be
    added to the doc in the future.
    NOTE(review): the 0.1 factor presumably is a per-node length expressed in the
    units of ``l_elec``. Confirm. The weight is not clipped to [0, 1].

    Parameters
    ----------
    sig_m_t : np.ndarray
        Node membrane conductivities at the given time for various locations.
    x_rec : np.ndarray
        Node positions along the fiber (only its length is used).
    sig_mye : float
        Constant myelin conductivity.
    sig_in : float
        Intracellular conductivity.
    sig_out : float
        Extracellular conductivity.
    d_ax : float
        Axon diameter.
    d_node : float
        Node diameter (rounded with ``sci_round(d_node, 5)`` before use).
    alpha_th : float
        Threshold parameter for conductivity adjustment.
    l_elec : float
        Electrode length.

    Returns
    -------
    sigma_2d : float
        Apparent 2D myelin conductivity for the simulated FEM segment.
    """
    d_node = float(sci_round(d_node, 5))
    n_nodes = len(x_rec)
    if n_nodes < 1:
        return sig_mye
    node_sig = get_sig_ap(sig_in, sig_m_t, alpha_th)
    node_sig = get_sig_ap(node_sig, sig_out, (d_ax - d_node) / d_ax)
    node_sig_mean = np.mean(node_sig)
    node_weight = 0.1 * n_nodes / l_elec
    return node_weight * node_sig_mean + (1 - node_weight) * sig_mye
def compute_sigma_2D(
    Y_m_t: np.ndarray,
    x_rec: np.ndarray,
    sig_in: float,
    sig_out: float,
    d_ax: np.ndarray,
    th_mem: float,
    l_elec: float,
    method="",
) -> float:
    """
    Compute the apparent 2D conductivity of a membrane at a given time.

    1. Each recording point gets a membrane patch of admittance
       ``Y_m_t * pi * d_ax * dx / th_mem``, where ``dx`` is the spacing to the next
       point (the last point reuses the first spacing).
    2. The patches are coupled through the axial factor
       ``1 / sig_out + 1 / sig_in`` and reduced to one equivalent admittance per
       point: exactly with ``compute_y_wth_uv``, or approximately with
       ``compute_y_app`` when ``"approx"`` is in ``method``.
    3. The equivalent admittances of the points strictly inside the electrode window
       ``((x_rec[-1] - l_elec) / 2, (x_rec[-1] + l_elec) / 2)`` are averaged and
       converted back into a conductivity.

    Parameters
    ----------
    Y_m_t : np.ndarray
        Node membrane conductivities at the given time for various locations.
    x_rec : np.ndarray
        Spatial positions (in micrometers) along the recording axis.
    sig_in : float
        Conductivity inside the membrane.
    sig_out : float
        Conductivity outside the membrane.
    d_ax : np.ndarray
        Membrane diameter(s) (in micrometers).
    th_mem : float
        Membrane thickness (in meters).
    l_elec : float
        Electrode length (in micrometers).
    method : str, optional
        If it contains "approx", the approximate reduction is used.

    Returns
    -------
    sigma_2d : float
        The computed 2D conductivity of the membrane. NaN (with a warning) if no
        point lies inside the electrode window.
    """
    x_rec_si = convert(x_rec, unitin="um", unitout="m")
    d_ax_si = convert(d_ax, unitin="um", unitout="m")
    l_elec_si = convert(l_elec, unitin="um", unitout="m")
    steps = np.diff(x_rec_si)
    steps = np.append(steps, steps[0])
    y_patch = Y_m_t * np.pi * d_ax_si * steps / th_mem
    r_lin = 1 / sig_out + 1 / sig_in
    if "approx" not in method:
        y_equiv = compute_y_wth_uv(y=y_patch, R_x=r_lin * steps)
    else:
        y_equiv = compute_y_app(y=y_patch, x_rec=x_rec_si, r_x=r_lin)
    l_fem = x_rec[-1]
    in_electrode = np.argwhere(
        (x_rec > (l_fem - l_elec) / 2) & (x_rec < (l_fem + l_elec) / 2)
    )
    return np.mean(y_equiv[in_electrode]) * th_mem / (np.pi * d_ax_si * l_elec_si)
def compute_sigma_2D_old(
    Y_m_t: np.ndarray, x_rec: np.ndarray, sig_in: float, sig_out: float, l_elec: float
) -> np.ndarray:
    """
    Legacy approximation of the apparent membrane conductivity, kept for reference.

    For every recording point ``n`` (positions converted from um to m):

        Y_eq[n] = sum(Y_m_t / (1 + |x_n - x| * r_x * Y_m_t)) / n_x

    with ``r_x = 1 / sig_out + 1 / sig_in``. The mean of ``Y_eq`` is returned.

    Note
    ----
    ``l_elec`` is accepted for signature compatibility but is not used. The result
    is stored in a real buffer, so complex inputs lose their imaginary part.
    """
    n_x = len(x_rec)
    x_rec_si = convert(x_rec, unitin="um", unitout="m")
    r_lin = 1 / sig_out + 1 / sig_in
    y_eq = np.zeros(n_x)
    for i_pt in range(n_x):
        coupling = abs(x_rec_si[i_pt] - x_rec_si) * r_lin
        y_eq[i_pt] = np.sum(Y_m_t / (1 + coupling * Y_m_t)) / n_x
    return np.mean(y_eq)
def sum_sigma_ax(results: nerve_results) -> np.ndarray:
    """
    Sum, over all axons, the membrane conductivity averaged along axis 0.

    For each axon (identified by the first two columns of
    ``results.axons_pop_properties``), ``get_membrane_conductivity()`` is averaged
    along axis 0 (presumably the spatial axis, giving a time course; confirm). The
    averages are then accumulated.

    Parameters
    ----------
    results : nerve_results
        Nerve simulation results exposing ``axons_pop_properties``, ``n_ax`` and
        ``get_axon_results``.

    Returns
    -------
    np.ndarray
        Summed mean membrane conductivity, or None when ``results.n_ax`` is 0.
    """
    pop_ppts = results.axons_pop_properties
    g_total = None
    for i_ax in range(results.n_ax):
        ax_row = pop_ppts[i_ax, :]
        axon_res = results.get_axon_results(ax_row[0], ax_row[1])
        g_mean = np.mean(axon_res.get_membrane_conductivity(), axis=0)
        if g_total is None:
            g_total = g_mean
        else:
            g_total += g_mean
    return g_total
## Additional on-the-fly post-processing functions
def sample_keys_mdt(
    results: axon_results,
    keys_to_sample: str | set[str] | None = None,
    sample_dt: list | None | float = None,
    t_start_rec: float = 0,
    t_stop_rec: float = -1,
    i_sampled_t: None | np.ndarray = None,
    x_bounds: None | float | tuple[float] = None,
    keys_to_remove: str | set[str] | None = None,
    keys_to_keep: set[str] | None = None,
) -> axon_results:
    """
    Extension of the ``sample_keys`` axon post-processing function from nrv that
    allows an adaptive dt.

    Note
    ----
    ``sample_dt`` may be a list of tuples ``(t_switch, dt_seg)``. Each tuple gives
    the sampling step ``dt_seg`` used until the time ``t_switch``, after which the
    next tuple applies. The last ``t_switch`` should be -1 (or any value <= 0) so
    that it is set to ``t_stop_rec``. If ``sample_dt`` is not iterable, the call is
    forwarded unchanged to ``sample_keys``.

    Parameters
    ----------
    results : axon_results
        Results of the axon simulation. ``results.is_recruited()`` is called first.
    keys_to_sample : str | set[str] | None, optional
        Keys to sample. None (default) forwards an empty container, as before.
    sample_dt : None | float | list[tuple[float, float]], optional
        Time sample rate at which values should be stored. If None, the simulation
        dt is kept. If a list, adaptive sampling is used (see Note).
    t_start_rec : float, optional
        Lower time at which values should be stored, by default 0.
    t_stop_rec : float, optional
        Upper time at which values should be stored, by default -1 (end of the
        simulation).
    i_sampled_t : None | np.ndarray, optional
        Explicit time indexes. Overridden when ``sample_dt`` is a list.
    x_bounds : None | float | tuple[float], optional
        x-positions where to store values, possible options:
        - float: values are only stored at the nearest position in ``x_rec``.
        - tuple: values are stored for all positions between the two boundaries.
        - None (default): values are stored for all positions.
    keys_to_remove : str | set[str] | None, optional
        Keys to remove, forwarded to ``sample_keys``.
    keys_to_keep : set[str] | None, optional
        Keys to keep, forwarded to ``sample_keys``. ``"recruited"`` is always added.

    Returns
    -------
    axon_results
        The value returned by ``sample_keys``.
    """
    # Fresh containers on every call: literal defaults would be shared between calls.
    if keys_to_sample is None:
        keys_to_sample = {}
    if keys_to_remove is None:
        keys_to_remove = set()
    if keys_to_keep is None:
        keys_to_keep = set()
    results.is_recruited()
    keys_to_keep = keys_to_keep.union({"recruited"})
    if not np.iterable(sample_dt):
        return sample_keys(
            results=results,
            keys_to_sample=keys_to_sample,
            t_start_rec=t_start_rec,
            t_stop_rec=t_stop_rec,
            sample_dt=sample_dt,
            i_sampled_t=i_sampled_t,
            x_bounds=x_bounds,
            keys_to_remove=keys_to_remove,
            keys_to_keep=keys_to_keep,
        )
    i_sampled_t = _adaptive_time_indices(
        np.asarray(results["t"]), results.dt, sample_dt, t_start_rec, t_stop_rec
    )
    return sample_keys(
        results=results,
        keys_to_sample=keys_to_sample,
        t_start_rec=t_start_rec,
        t_stop_rec=t_stop_rec,
        sample_dt=None,
        i_sampled_t=i_sampled_t,
        x_bounds=x_bounds,
        keys_to_remove=keys_to_remove,
        keys_to_keep=keys_to_keep,
    )


def _adaptive_time_indices(t, dt, sample_dt, t_start_rec, t_stop_rec):
    """Return the time indexes sampled with the piecewise-constant steps of ``sample_dt``."""
    if t_stop_rec < 0:
        i_t_max = len(t)
    else:
        i_t_max = np.argwhere(t <= t_stop_rec)[-1][0]
    i_t_start = np.argwhere(t >= t_start_rec)[0][0]
    i_sampled = []
    for t_switch, cur_dt in sample_dt:
        if t_switch > 0:
            i_switch_dt = np.argwhere(t <= t_switch)[-1][0]
        else:
            i_switch_dt = i_t_max
        # never sample beyond the end of the recording window
        i_switch_dt = min(i_switch_dt, i_t_max)
        # The 1e-9 tolerance absorbs float error (0.3 / 0.1 == 2.9999999999999996
        # must give a step of 3, not 2). A dt_seg smaller than the simulation dt
        # keeps every step instead of making range() fail with a zero step.
        step = max(1, int(cur_dt / dt + 1e-9))
        i_sampled += list(range(i_t_start, i_switch_dt, step))
        # never move the segment start backwards (e.g. t_switch before t_start_rec)
        i_t_start = max(i_t_start, i_switch_dt)
    return np.array(i_sampled)
[docs]
def get_samples_index(
    results: nerve_results,
    n_pts: int,
    alpha: float = 0.001,
    t_iclamp: float = 1,
    d_iclamp: float = 0.2,
    n_pts_min=None,
) -> np.ndarray:
    """
    Select sample indexes from nerve simulation results so that the samples are
    evenly spread along the arc length of the signal.

    Note
    ----
    For now, the sampling indexes are computed from the analytical recorder's
    variation: indexes are denser where the recorder's values change faster.

    Warning
    -------
    In a future version, the previous note might be extended to the global variation
    of conductivity in axons' membranes in the nerve (instead of only the analytical
    recorder).

    Parameters
    ----------
    results : nerve_results
        The results object containing simulation recordings and time points.
    n_pts : int
        The desired number of sample points.
    alpha : float, optional
        Regularization parameter for arc length calculation (default is 0.001).
    t_iclamp : float, optional
        Start time of the current clamp artifact to be removed (default is 1).
    d_iclamp : float, optional
        Duration of the current clamp artifact to be removed (default is 0.2).
    n_pts_min : int, optional
        Minimum number of distinct sample points to return. Defaults to ``n_pts``.
        It is capped at the number of time steps; when it reaches that number,
        every index is returned.

    Returns
    -------
    np.ndarray
        Indexes of the selected sample points, distributed homogeneously along the
        arc length of the signal.

    Raises
    ------
    ValueError
        If ``results`` has no ``"recorder"`` entry. Previously this ended in an
        UnboundLocalError.

    Note
    ----
    - Removes the effect of the current clamp artifact from the signal before
      sampling.
    """
    # TODO: From sum of gmem instead of recorder
    if "recorder" not in results:
        raise ValueError(
            "get_samples_index requires results containing a 'recorder' entry"
        )
    if n_pts_min is None:
        n_pts_min = n_pts
    t = np.array(results.recorder.t)
    t_sim = t[-1]
    dt = t[1] - t[0]
    v = np.array(results.recorder.recording_points[0].recording)
    # removing change due to iclamp
    v[int(t_iclamp / dt) : int((t_iclamp + d_iclamp + 0.01) / dt)] *= 0
    # NOTE(review): only the time axis is normalized. The recorder values are used
    # as-is; a normalized |v| used to be computed but was never used. Confirm
    # whether dv should come from a normalized signal.
    norm_dt = dt / t_sim
    dv = np.diff(v, prepend=v[0])
    drec_dt = (dv**2 + (alpha * norm_dt) ** 2) ** 0.5
    arc = np.cumsum(drec_dt)
    # There cannot be more distinct indexes than time steps. The former recursion
    # never terminated in that case; asking for all of them returns every index.
    if n_pts_min >= len(t):
        return np.arange(len(t))
    i_t_samples = _arc_length_samples(arc, n_pts)
    # densify until enough distinct indexes are obtained (loop instead of recursion)
    while len(i_t_samples) < n_pts_min:
        n_pts += 1
        i_t_samples = _arc_length_samples(arc, n_pts)
    return i_t_samples


def _arc_length_samples(arc: np.ndarray, n_pts: int) -> np.ndarray:
    """Return the indexes of ``arc`` closest to ``n_pts`` evenly spaced arc-length targets, duplicates removed."""
    if n_pts < 1:
        return np.zeros(0, dtype=int)
    # max(..., 1) avoids dividing by zero when a single point is requested
    length_sample = arc[-1] / max(n_pts - 1, 1)
    i_samples = np.array(
        [np.argmin(abs(arc - k * length_sample)) for k in range(n_pts)]
    )
    # remove repeated indexes (keeps the last of each run)
    ok_mask = np.append(np.diff(i_samples) != 0, [True])
    return i_samples[ok_mask]
[docs]
def sample_nerve_results(
    results: nerve_results,
    n_pts: int,
    alpha: float = 0.001,
    t_iclamp: float = 1,
    d_iclamp: float = 0.2,
    keys_to_sample="g_mem",
) -> nerve_results:
    """
    Sample specific keys of every axon's results at adaptively selected time points.

    The time indexes come from ``get_samples_index(results, n_pts, ...)``. They are
    applied with ``sample_keys`` to every axon listed in ``results.axons`` (columns
    ``"fkey"`` and ``"akey"``).

    Note
    ----
    Unlike :func:`sample_keys`, this function must be called after the nerve
    simulation, on the whole :class:`nerve_results`.

    Parameters
    ----------
    results : nerve_results
        The nerve simulation results object containing axon population properties
        and results.
    n_pts : int
        Number of time points to sample from the results.
    alpha : float, optional
        Regularization parameter forwarded to ``get_samples_index`` (default 0.001).
    t_iclamp : float, optional
        Time of current clamp onset in milliseconds (default is 1).
    d_iclamp : float, optional
        Duration of current clamp in milliseconds (default is 0.2).
    keys_to_sample : str or list of str, optional
        Keys of the results to sample (default is "g_mem").

    Returns
    -------
    nerve_results
        The same ``results`` object. The return value of ``sample_keys`` is
        discarded, so the per-axon results are presumably modified in place by
        ``sample_keys``.
    """
    import pandas as pd

    sampled_t_idx = get_samples_index(
        results, n_pts, alpha=alpha, d_iclamp=d_iclamp, t_iclamp=t_iclamp
    )
    axons_df: pd.DataFrame = results.axons
    for i_ax in axons_df.index:
        axon_row = axons_df.loc[i_ax]
        fiber_res = results[axon_row["fkey"]]
        sample_keys(
            fiber_res[axon_row["akey"]],
            keys_to_sample=keys_to_sample,
            i_sampled_t=sampled_t_idx,
        )
    return results
## Post-processing
def compute_v_rec_cap_idxs(
    Voltage, dt, t_stim=0, stim_duration=0.2, tol=0.05, use_filter=True
):
    """
    Locate two compound action potential (CAP) windows in a recorded voltage trace.

    Presumably the first window is the myelinated CAP and the second the
    unmyelinated one, following the ``_m`` / ``_u`` naming; confirm with callers.

    Parameters
    ----------
    Voltage : np.ndarray
        Recorded voltage (assumed 1D and sampled every ``dt``; TODO confirm).
    dt : float
        Sampling step, in the same unit as ``t_stim`` and ``stim_duration``.
    t_stim : float, optional
        Stimulation start time, by default 0.
    stim_duration : float, optional
        Stimulation duration, by default 0.2. Samples before
        ``int((t_stim + stim_duration) / dt) + 1`` are skipped.
    tol : float, optional
        Detection threshold for the first window, relative to the normalized
        post-stimulation voltage (values below ``-tol``), by default 0.05.
    use_filter : bool, optional
        If True, the derivative used for the second window is smoothed with
        ``savgol_filter(window_length=1000, polyorder=3)``, by default True.

    Returns
    -------
    i_cap_m : tuple
        ``(i_start, i_min, i_max, i_stop)`` for the first window, or ``(0, 0, 0, 0)``
        when nothing is detected.
    i_cap_u : tuple
        ``(i_start, i_min, i_max, i_stop)`` for the second window.

    Note
    ----
    NOTE(review) known fragilities, left unchanged:
    - ``np.argwhere(...).squeeze()`` returns a 0-d array when a single index
      matches, and ``len()``/``[0]`` then fail.
    - The normalization divides by ``-min``, which assumes the post-stimulation
      minimum is negative.
    - ``savgol_filter`` with ``window_length=1000`` (default ``mode="interp"``)
      needs at least 1000 derivative samples.
    """
    # first index after the stimulation window
    i_offset_stim = int((t_stim + stim_duration) / dt) + 1
    # post-stimulation trace scaled so that its minimum becomes -1 (if min < 0)
    _v_rec_nrm = Voltage[i_offset_stim:].copy()
    _v_rec_nrm /= -_v_rec_nrm.min()
    # indexes (relative to i_offset_stim) where the scaled trace is below -tol
    i_t_m = np.argwhere(_v_rec_nrm < -tol).squeeze()
    if len(i_t_m) == 0:
        print("No mylinated cap detected")
        i_cap_m = 0, 0, 0, 0
    else:
        # Locate breaks between runs of consecutive indexes to bound the first
        # window. The exact run selection relies on the prepend/append values below.
        di_t_m = np.diff(i_t_m[:-1], prepend=-1, append=0)
        i_cut = np.squeeze(np.where(di_t_m != 1)) - 1
        if not np.iterable(i_cut):
            i_start_m, i_stop_m = i_t_m[0], i_t_m[i_cut]
        else:
            if i_cut[0] != 0:
                i_start_m, i_stop_m = i_t_m[0], i_t_m[i_cut[0]]
            else:
                i_start_m, i_stop_m = i_t_m[i_cut[:2]]
        # back to absolute indexes in Voltage
        i_start_m += i_offset_stim
        i_stop_m += i_offset_stim
        # extrema inside the first window (absolute indexes)
        i_t_min, i_t_max = (
            np.argmin(Voltage[i_start_m:i_stop_m]) + i_start_m,
            np.argmax(Voltage[i_start_m:i_stop_m]) + i_start_m,
        )
        i_cap_m = i_start_m, i_t_min, i_t_max, i_stop_m
    # Second window: start at the first positive sample after the first window.
    i_start = max(i_cap_m[-1], i_offset_stim)
    # NOTE(review): the search slice starts at i_cap_m[-1], not at i_start. When no
    # first window was found (i_cap_m[-1] == 0), an offset counted from index 0 is
    # added to i_offset_stim. Voltage[i_start:] was probably intended; confirm.
    i_start += np.argwhere(Voltage[i_cap_m[-1] :] > 0)[0][0]
    dv_rec_dt = np.diff(Voltage[i_start:])
    if use_filter:
        dv_rec_dt = savgol_filter(dv_rec_dt, 1000, 3)
        dv_rec_dt /= np.max(abs(dv_rec_dt))
    else:
        dv_rec_dt /= np.max(abs(dv_rec_dt))
    # second window: first to last sample where the normalized derivative exceeds 0.1
    i_t_u = np.argwhere(dv_rec_dt > 0.1).squeeze() + i_start
    i_t_start, i_t_stop = i_t_u[0], i_t_u[-1]
    i_t_min, i_t_max = (
        np.argmin(Voltage[i_t_start:i_t_stop]) + i_t_start,
        np.argmax(Voltage[i_t_start:i_t_stop]) + i_t_start,
    )
    i_cap_u = i_t_start, i_t_min, i_t_max, i_t_stop
    return i_cap_m, i_cap_u