"""
NRV-Cellular Level postprocessing.
.. warning::
CL_postprocessing functions are deprecated. Use axon_results and associated methods instead.
"""
import faulthandler
from collections.abc import Iterable
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from ..backend._log_interface import pass_info, rise_warning, rise_error
from ..utils._units import MHz
from ..utils._misc import distance_point2line
from ..nmod._unmyelinated import unmyelinated
from ..nmod._myelinated import myelinated
from ..nmod.results._axons_results import axon_results
# enable faulthandler to ease 'segmentation faults' debug
faulthandler.enable()
#############################################
## HANDLE THE SIMULATION RESULT DICTIONARY ##
#############################################
[docs]
def remove_key(my_dict, key, verbose=False):
    """
    Delete a single entry from a results dictionary.

    Useful before saving files, as some results may be heavy and are
    potentially useless after some steps of postprocessing.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary to modify in place
    key : str
        name of the key to delete
    verbose : bool
        forwarded to ``pass_info`` when reporting the deleted key, by default False
    """
    deprecation_note = (
        "remove_key is obsolete use method from axon_result objects instead"
    )
    rise_warning("DeprecationWarning: ", deprecation_note)
    # the entry is dropped before the report, so a failing deletion is never
    # reported as done
    del my_dict[key]
    pass_info("removed the following key from results: ", key, verbose=verbose)
[docs]
def remove_non_NoR_zones(my_dict, key):
    """
    Keep only the rows located at nodes of Ranvier for a membrane-voltage entry.

    The entry ``my_dict[key]`` is replaced in place by an array holding only
    the rows listed in ``my_dict["Nodes_of_Ranvier_indexes"]``. Helpful for
    large simulations before saving results.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary where the quantity should be cleaned
    key : str
        name of the key to clean; it must contain ``"V_mem"``, otherwise only
        a warning is emitted

    Note
    ----
    Nothing is modified (a warning is emitted instead) when
    ``my_dict["Axon_type"]`` is not ``"Myelinated"``.
    """
    rise_warning(
        "DeprecationWarning: ",
        "remove_non_NoR_zones is obsolete use method from axon_result objects instead",
    )
    if "V_mem" not in key:
        rise_warning(
            "Warning, remove_non_NoR_zones only applicable to membrane voltage or current"
        )
        return
    if my_dict["Axon_type"] != "Myelinated":
        rise_warning(
            "Warning, remove_non_NoR_zones only applicable to Myelinated axons"
        )
        return
    # one row per node of Ranvier, in the order given by the index list
    node_rows = [
        my_dict[key][node_index, :]
        for node_index in my_dict["Nodes_of_Ranvier_indexes"]
    ]
    my_dict[key] = np.asarray(node_rows)
[docs]
def generate_axon_from_results(results):
    """
    Build an axon object from the content of a results dictionary.

    The whole dictionary is unpacked as keyword arguments of the axon
    constructor; the class is chosen from ``results["myelinated"]``.

    Parameters
    ----------
    results : dict
        results of axon.simulate method

    Returns
    -------
    ax : unmyelinated or myelinated
        axon instantiated with ``**results``
    """
    axon_class = myelinated if results["myelinated"] else unmyelinated
    return axon_class(**results)
############################
## AXON SIGNAL PROCESSING ##
############################
[docs]
def filter_freq(my_dict, my_key, freq, Q=10):
    """
    Basic filtering of quantities with one or several notch filters (scipy IIR-notch).

    Adds an item to the specified dictionary, with the key termination
    '_filtered' concatenated to the original key.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary where the quantity should be filtered
    my_key : str
        name of the key to filter
    freq : float or array, list, np.array
        frequency or list of frequencies to filter in kHz, as time is defined in ms in NRV2.
        If multiple frequencies, they are filtered sequentially, with as many
        filters as frequencies, in the specified order
    Q : float
        quality factor of the filter, by default set to 10

    Returns
    -------
    bool or None
        False when ``my_dict["dt"]`` is 0 (filtering aborted, nothing is
        stored), None otherwise.
    """
    rise_warning(
        "DeprecationWarning: ",
        "filter_freq is obsolete use method from axon_result objects instead",
    )
    if my_dict["dt"] == 0:
        rise_warning(
            "Warning: filtering aborted, variable time step used for differential equation solving"
        )
        return False
    fs = 1 / my_dict["dt"]
    # work on a float copy so the original entry is left untouched
    filtered = np.array(my_dict[my_key], dtype=float)
    # np.ravel turns a scalar frequency into a one-element sequence, so the
    # single- and multi-frequency cases share the same code path
    for f in np.ravel(freq):
        b_notch, a_notch = signal.iirnotch(f, Q, fs)
        # axis=-1: every row (position) is filtered independently along time.
        # The previous implementation passed a single sample of the row to
        # lfilter in the multi-frequency case.
        filtered = signal.lfilter(b_notch, a_notch, filtered, axis=-1)
    my_dict[my_key + "_filtered"] = filtered
[docs]
def rasterize(
    my_dict, my_key, t_start=0, t_stop=0, t_min_spike=0.1, t_refractory=2, threshold=0
):
    """
    Detect spikes in a membrane potential (or any quantity derived from it) and store the raster.

    Four items are added to the dictionary, their keys being ``my_key``
    followed by:

    - ``_raster_position`` : spike position as the index in the original key
    - ``_raster_x_position`` : spike position as geometrical position in um
    - ``_raster_time_index`` : spike time as the index in the original key
    - ``_raster_time`` : spike time in ms

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary where the quantity should be rasterized
    my_key : str
        name of the key to rasterize
    t_start : float
        time at which the spike detection should start, in ms. By default 0
    t_stop : float
        maximum time to apply spike detection, in ms. If zero is specified,
        the spike detection is applied to the full signal duration
        (``my_dict["t_sim"]``). By default set to 0.
    t_min_spike : float
        minimum duration of a spike over its threshold, in ms. By default set to 0.1 ms
    t_refractory : float
        refractory period for a spike, in ms. By default set to 2 ms.
    threshold : float
        threshold for spike detection, in mV. If 0 is specified,
        ``my_dict["threshold"]`` is used. By default set to 0.
        Note that if a 0 value is wanted as threshold, an insignificant value
        (eg. 1e-12) should be specified.
    """
    rise_warning(
        "DeprecationWarning: ",
        "rasterize is obsolete use method from axon_result objects instead",
    )
    # the detection window end is converted from ms to a sample index
    t_stop = int((my_dict["t_sim"] if t_stop == 0 else t_stop) / my_dict["dt"])
    thr = my_dict["threshold"] if threshold == 0 else threshold
    # only a myelinated axon recorded everywhere is parsed through its node
    # indexes; every other case parses all recorded positions
    # (the explicit comparison with True is kept on purpose)
    if my_dict["myelinated"] == True and my_dict["rec"] == "all":
        list_to_parse = my_dict["node_index"]
        x = my_dict["x"]
    else:
        list_to_parse = np.arange(len(my_dict["x_rec"]))
        x = my_dict["x_rec"]
    detected = AP_detection(
        my_dict[my_key],
        my_dict["t"],
        x,
        list_to_parse,
        thr,
        my_dict["dt"],
        t_start,
        t_stop,
        t_refractory,
        t_min_spike,
    )
    raster_suffixes = (
        "_raster_position",
        "_raster_x_position",
        "_raster_time_index",
        "_raster_time",
    )
    for suffix, raster in zip(raster_suffixes, detected):
        my_dict[my_key + suffix] = raster
# @jit(nopython=True, fastmath=True)
[docs]
def AP_detection(
    Voltage, t, x, list_to_parse, thr, dt, t_start, t_stop, t_refractory, t_min_spike
):
    """
    Internal use only: threshold-crossing spike detection.

    For every position index in ``list_to_parse``, samples from
    ``int(t_start / dt)`` up to ``t_stop`` (excluded) are scanned. A spike is
    recorded at sample ``j`` when:

    - the trace is at or below ``thr`` at ``j`` and at or above it at ``j + 1``,
    - it is still at or above ``thr`` ``t_min_spike`` later (index clamped to ``t_stop``),
    - more than ``t_refractory`` elapsed since the previous spike at this position.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Returns
    -------
    tuple of np.ndarray
        ``(raster_position, raster_x_position, raster_time_index, raster_time)``
    """
    rise_warning(
        "DeprecationWarning: ",
        "AP_detection is obsolete use method from axon_result objects instead",
    )
    spike_positions, spike_x, spike_indexes, spike_times = [], [], [], []
    for i in list_to_parse:
        trace = Voltage[i]
        # start as if a spike had just left its refractory period
        last_spike_time = t_start - t_refractory
        for j in range(int(t_start * (1 / dt)), t_stop):
            # upward threshold crossing between samples j and j + 1
            if not (trace[j] <= thr and trace[j + 1] >= thr):
                continue
            # minimum time above threshold
            hold_index = min((j + int(t_min_spike * (1 / dt))), t_stop)
            if not trace[hold_index] >= thr:
                continue
            # refractory period
            if not (j * dt - last_spike_time) > t_refractory:
                continue
            spike_positions.append(i)
            spike_x.append(x[i])
            spike_indexes.append(j)
            spike_times.append(t[j])
            # memorize the time in ms, to evaluate refractory period
            last_spike_time = j * dt
    return (
        np.asarray(spike_positions),
        np.asarray(spike_x),
        np.asarray(spike_indexes),
        np.asarray(spike_times),
    )
[docs]
def speed(my_dict, position_key=None, t_start=0, t_stop=0, x_start=0, x_stop=0):
    """
    Compute the velocity of a spike from rasterized data in a dictionary.

    The speed can be either positive or negative depending on the propagation direction.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary containing the rasterized data
    position_key : str
        prefix of the raster keys to consider (``<prefix>_time`` and
        ``<prefix>_x_position`` are read). If None is specified, the raster is
        automatically chosen with preference for filtered-rasterized keys.
    t_start : float
        time from which the spikes are processed, in ms. If 0 and
        ``my_dict["intra_stim_starts"]`` is a non-empty list, its first element
        is used. By default 0
    t_stop : float
        maximum time at which the spikes are processed, in ms. If zero is
        specified, ``my_dict["t_sim"]`` is used. By default set to 0.
    x_start : float
        minimum position for spike processing, in um. By default set to 0.
    x_stop : float
        maximum position for spike processing, in um. If 0 is specified, spikes
        are processed up to ``my_dict["L"]``. By default set to 0.

    Returns
    -------
    speed : float or bool
        velocity, or False when there is no raster in the dictionary or no
        spike in the requested window.

    Note
    ----
    The velocity is computed with first and last occurrence of a spike, be
    careful specifying the computation window if multiple spikes. This function
    will not see velocity variation.
    """
    rise_warning(
        "DeprecationWarning: ",
        "speed is obsolete use method from axon_result objects instead",
    )
    # define max timing if not already defined
    if t_stop == 0:
        t_stop = my_dict["t_sim"]
    if t_start == 0:
        if "intra_stim_starts" in my_dict and my_dict["intra_stim_starts"] != []:
            t_start = my_dict["intra_stim_starts"][0]
    # only the upper bound has a "0 means full length" convention; a user
    # supplied x_start or x_stop is never overwritten
    if x_stop == 0:
        x_stop = my_dict["L"]
    # find the best raster plot
    if position_key is None:
        if "V_mem_filtered_raster_position" in my_dict:
            good_key_prefix = "V_mem_filtered_raster"
        elif "V_mem_raster_position" in my_dict:
            good_key_prefix = "V_mem_raster"
        else:
            # there is no rasterized voltage, nothing to find
            return False
    else:
        good_key_prefix = position_key
    spike_times = np.asarray(my_dict[good_key_prefix + "_time"])
    spike_positions = np.asarray(my_dict[good_key_prefix + "_x_position"])
    # a single boolean mask keeps time and position selections aligned on the
    # same (absolute) spike indexes
    in_window = (
        (spike_times > t_start)
        & (spike_times < t_stop)
        & (spike_positions >= x_start)
        & (spike_positions <= x_stop)
    )
    good_spike_times = spike_times[in_window]
    good_spike_positions = spike_positions[in_window]
    if good_spike_times.size == 0:
        rise_warning("Warning: no spike found in the specified window, speed not computed")
        return False
    max_time = np.argmax(good_spike_times)
    min_time = np.argmin(good_spike_times)
    return (
        (good_spike_positions[max_time] - good_spike_positions[min_time])
        * 10**-3
        / (good_spike_times[max_time] - good_spike_times[min_time])
    )
[docs]
def block(my_dict, position_key=None, t_start=0, t_stop=0):
    """
    Check if an axon is blocked or not.

    The simulation has to include the test spike. This function looks at the
    rasterized spikes of the time window and checks how far they propagate.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    my_dict : dictionary
        dictionary containing the rasterized data
    position_key : str
        prefix of the raster keys to consider (``<prefix>_time`` and
        ``<prefix>_x_position`` are read). If None is specified, the raster is
        automatically chosen with preference for filtered-rasterized keys.
    t_start : float
        time from which the test spikes can occur, in ms. If 0 and
        ``my_dict["intra_stim_starts"]`` is a non-empty list, its first element
        is used. By default 0
    t_stop : float
        maximum time at which the spikes are processed, in ms. If zero is
        specified, ``my_dict["t_sim"]`` is used. By default set to 0.

    Returns
    -------
    flag : bool or None
        True if the axon is blocked, False if not blocked (also returned when
        no raster is available), None if no spike lies in the time window or if
        ``intra_stim_position`` is missing from the dictionary.
    """
    rise_warning(
        "DeprecationWarning: ",
        "block is obsolete use method from axon_result objects instead",
    )
    if t_stop == 0:
        t_stop = my_dict["t_sim"]
    if t_start == 0:
        if "intra_stim_starts" in my_dict and my_dict["intra_stim_starts"] != []:
            t_start = my_dict["intra_stim_starts"][0]
    if position_key is None:
        if "V_mem_filtered_raster_position" in my_dict:
            good_key_prefix = "V_mem_filtered_raster"
        elif "V_mem_raster_position" in my_dict:
            good_key_prefix = "V_mem_raster"
        else:
            # there is no rasterized voltage, nothing to find
            return False
    else:
        # same convention as in speed(): the argument is the raster key prefix
        good_key_prefix = position_key
    spike_times = np.asarray(my_dict[good_key_prefix + "_time"])
    in_window = (spike_times > t_start) & (spike_times < t_stop)
    blocked_spike_positionlist = np.asarray(
        my_dict[good_key_prefix + "_x_position"]
    )[in_window]
    if len(blocked_spike_positionlist) == 0:
        return None
    if "intra_stim_position" not in my_dict:
        pass_info("intra_stim_position is not in dictionnary")
        return None
    axon_length = my_dict["L"]
    max_gap = axon_length / 5
    if my_dict["intra_stim_position"] < my_dict["extracellular_electrode_x"]:
        # test spike initiated below the electrode: positions should increase
        position_max = max_spike_position(
            blocked_spike_positionlist, 0, spike_begin="down"
        )
        if blocked_spike_positionlist[position_max] < 9.0 / 10 * axon_length:
            return True
        # blocked as soon as one jump between consecutive spikes is too large
        return any(
            blocked_spike_positionlist[i + 1] - blocked_spike_positionlist[i]
            > max_gap
            for i in range(position_max - 1)
        )
    # test spike initiated at or above the electrode: positions should decrease
    position_max = max_spike_position(
        blocked_spike_positionlist, 0, spike_begin="up"
    )
    if min(blocked_spike_positionlist) > 1.0 / 10 * axon_length:
        return True
    return any(
        blocked_spike_positionlist[i] - blocked_spike_positionlist[i + 1] > max_gap
        for i in range(position_max - 1)
    )
[docs]
def max_spike_position(blocked_spike_positionlist, position_max, spike_begin="down"):
    """
    Extend a blocked-spike position index to the corresponding local extremum.

    Starting from ``position_max``, the index is advanced while the next
    position keeps going in the expected direction (non-decreasing for
    ``"down"``, non-increasing otherwise). The returned index never exceeds
    ``len(blocked_spike_positionlist) - 2`` unless the initial index already does.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    blocked_spike_positionlist : list[float] | np.ndarray
        Ordered list of blocked-spike positions.
    position_max : int
        Initial index of the spike position under inspection.
    spike_begin : str, optional
        Direction used to follow the spike front, either ``"down"`` or the
        opposite direction.

    Returns
    -------
    int
        Index of the detected local extremum in ``blocked_spike_positionlist``.
    """
    rise_warning(
        "DeprecationWarning: ",
        "max_spike_position is obsolete use method from axon_result objects instead",
    )
    follow_increasing = spike_begin == "down"
    last_start = len(blocked_spike_positionlist) - 2
    # the bound is tested before indexing, so lists with fewer than two
    # elements no longer raise an IndexError
    while position_max < last_start:
        current = blocked_spike_positionlist[position_max]
        following = blocked_spike_positionlist[position_max + 1]
        if follow_increasing:
            still_extending = following >= current
        else:
            still_extending = following <= current
        if not still_extending:
            break
        position_max += 1
    return position_max
# @jit(nopython=True, fastmath=True)
[docs]
def count_spike(onset_position):
    """
    Spike counting from a sequence of onset positions. For internal use only.

    The count starts at 1 and is incremented for every pair of consecutive
    entries that are equal to each other and to the minimum of the sequence.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    onset_position : list or np.ndarray
        sequence of onset positions

    Returns
    -------
    int
        0 for an empty sequence, otherwise 1 plus the number of consecutive
        repetitions of the minimum.
    """
    rise_warning(
        "DeprecationWarning: ",
        "count_spike is obsolete use method from axon_result objects instead",
    )
    if len(onset_position) == 0:
        return 0
    # the minimum is computed once instead of once per element
    lowest = min(onset_position)
    repeated_minima = sum(
        1
        for current, following in zip(onset_position, onset_position[1:])
        if current == lowest and current == following
    )
    return 1 + repeated_minima
[docs]
def check_test_AP(results_sim):
    """
    Check if a test AP is correctly triggered during an axon simulation and if
    so return the trigger time.

    The first value of ``results_sim["intra_stim_starts"]`` is the expected
    trigger time ``t0``. The test AP is considered triggered when the raster
    ``V_mem_raster_*`` contains a spike at x position 0 with a time in
    ``[t0 - 0.01, t0 + 0.7]``.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    results_sim : dict or str
        results of axon.simulate method, or a string handed to
        ``load_simulation_from_json``

    Returns
    -------
    test_AP : float or None
        time in ms of the first test AP during the simulation. None if no test AP found
    """
    rise_warning(
        "DeprecationWarning: ",
        "check_test_AP is obsolete use method from axon_result objects instead",
    )
    if isinstance(results_sim, str):
        # NOTE(review): load_simulation_from_json is not imported in the
        # visible part of this module -- confirm it is in scope
        results_sim = load_simulation_from_json(results_sim)
    if "intra_stim_starts" not in results_sim:
        return None
    test_AP = results_sim["intra_stim_starts"]
    # accept both a sequence of start times and a single scalar
    if np.iterable(test_AP):
        if len(test_AP) == 0:
            return None
        test_AP = test_AP[0]
    if test_AP is None:
        return None
    raster_x = np.asarray(results_sim["V_mem_raster_x_position"])
    raster_times = np.asarray(results_sim["V_mem_raster_time"])
    times_at_origin = raster_times[raster_x == 0]
    triggered = np.any(
        (times_at_origin >= test_AP - 0.01) & (times_at_origin <= test_AP + 0.7)
    )
    return test_AP if triggered else None
[docs]
def axon_state(results_sim, save=False, saving_file="axon_state.json"):
    """
    Return axon characteristics (blocked, onset response, ...) from axon simulation results.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    results_sim : dict or str
        simulation results dictionary, or a string handed to
        ``load_simulation_from_json``
    save : bool
        if True the result is handed to ``save_axon_results_as_json``
    saving_file : str
        if save is True path and name of the saving file

    Returns
    -------
    axon_state : dict
        dictionary with the keys ``ID``, ``parameters``, ``block_state``,
        ``onset_state`` and ``onset number``

    Note
    ----
    When a test AP is found and ``extracellular_electrode_x`` is missing from
    ``results_sim``, the key is added to ``results_sim`` with the value 0.
    """
    rise_warning(
        "DeprecationWarning: ",
        "axon_state is obsolete use method from axon_result objects instead",
    )
    if isinstance(results_sim, str):
        # NOTE(review): load_simulation_from_json is not imported in the
        # visible part of this module -- confirm it is in scope
        results_sim = load_simulation_from_json(results_sim)
    ID = results_sim["ID"]

    # Axon parameter
    parameters = {}
    if "diameter" in results_sim:
        parameters["diameter"] = results_sim["diameter"]
    if "myelinated" in results_sim and results_sim["myelinated"]:
        parameters["node"] = len(results_sim["x_nodes"])
    if (
        "extracellular_electrode_y" in results_sim
        and len(results_sim["extracellular_electrode_y"]) == 1
    ):
        parameters["distance electrod"] = distance_point2line(
            results_sim["y"],
            results_sim["z"],
            results_sim["extracellular_electrode_y"][0],
            results_sim["extracellular_electrode_z"][0],
        )
        if results_sim["myelinated"]:
            x_nodes = results_sim["x_nodes"]
            x_elec = results_sim["extracellular_electrode_x"][0]
            elec_node = np.argmin(abs(x_nodes - x_elec))
            # offset to the closest node, normalised by the spacing between
            # the first two nodes
            elec_ali = (x_nodes[elec_node] - x_elec) / (x_nodes[1] - x_nodes[0])
            parameters["electrod node"] = int(elec_node)
            parameters["electrod alignment"] = float(elec_ali)

    # Check Block
    test_AP = check_test_AP(results_sim)
    if test_AP is None:
        block_state = None
    else:
        if "extracellular_electrode_x" not in results_sim:
            results_sim["extracellular_electrode_x"] = 0
        block_state = block(results_sim, t_start=test_AP - 0.001)

    # Check Onset Response
    pos = np.asarray(results_sim["V_mem_raster_position"])
    positions_key = "x_nodes" if results_sim["myelinated"] else "x_rec"
    # index of the last position: the previous code measured the length of a
    # one-element list literal and therefore always compared against 1
    last_index = len(results_sim[positions_key]) - 1
    n_first = np.count_nonzero(pos == 0)
    n_last = np.count_nonzero(pos == last_index)
    # Count Onset response: each AP is expected once at both ends
    N_AP = (n_first + n_last) / 2
    if test_AP is not None:
        # discount the test AP itself (seen at one end only when blocked)
        if block_state:
            N_AP -= 0.5
        else:
            N_AP -= 1
    onset_state = N_AP > 0

    state = {
        "ID": ID,
        "parameters": parameters,
        "block_state": block_state,
        "onset_state": onset_state,
        "onset number": N_AP,
    }
    if save:
        # NOTE(review): save_axon_results_as_json is not imported in the
        # visible part of this module -- confirm it is in scope
        save_axon_results_as_json(state, saving_file)
    return state
##############################
## Axon properties function ##
##############################
[docs]
def get_index_myelinated_sequence(results, n):
    """
    Return the section type associated with the n-th recorded point.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    results : dict
        results of axon.simulate method
    n : int
        index of the recorded point; a warning is emitted when it is larger
        than ``len(results["x_rec"])``

    Returns
    -------
    str or element of ``results["axon_path_type"]``
        ``"node"`` when the axon is not myelinated or ``results["rec"]`` is
        ``"node"``; otherwise the entry of ``results["axon_path_type"]``
        selected from ``n``.
    """
    rise_warning(
        "DeprecationWarning: ",
        "get_index_myelinated_sequence is obsolete use method from axon_result objects instead",
    )
    if not results["myelinated"] or results["rec"] == "node":
        return "node"
    if n > len(results["x_rec"]):
        rise_warning("index not in axon")
    # +1 required because nbr of computation point = nbr seg/sec + 1
    # see if it's a bug
    points_per_section = results["Nseg_per_sec"] + 1
    # the section pattern is taken modulo 11 -- presumably the number of
    # sections in one repeated sequence; TODO confirm
    n_section_types = 11
    section_types = results["axon_path_type"]
    if n == 0:
        return section_types[0]
    return section_types[((n - 1) // points_per_section) % n_section_types]
[docs]
def find_central_node_index(results: axon_results) -> int:
    """
    Deprecated wrapper around ``results.find_central_index()``.

    Parameters
    ----------
    results : axon_results
        results of axon.simulate method

    Returns
    -------
    n_center : int
        value returned by ``results.find_central_index()``
    """
    deprecation_note = (
        "find_central_node_index is obsolete use method from axon_result objects instead"
    )
    rise_warning("DeprecationWarning: ", deprecation_note)
    n_center = results.find_central_index()
    return n_center
[docs]
def compute_f_mem(results):
    """
    Compute the cutoff frequency of the axon's membrane and add it to the
    simulation results dictionary.

    NB: The frequency is computed in [kHz]

    The result is cached: when ``"f_mem"`` is already in ``results`` it is
    returned as is. Otherwise ``"c_mem"`` and ``"f_mem"`` are added to ``results``.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    results : dict
        results of axon.simulate method

    Returns
    -------
    f_mem : np.ndarray or None
        value of the cutoff frequency of the axon's membrane, None when
        ``"g_mem"`` is not in ``results``
    """
    rise_warning(
        "DeprecationWarning: ",
        "compute_f_mem is obsolete use method from axon_result objects instead",
    )
    if "g_mem" not in results:
        rise_warning("f_mem cannot be computed computed without membrane conductivity")
        return None
    if "f_mem" in results:
        return results["f_mem"]
    results["c_mem"] = generate_axon_from_results(results).get_membrane_capacitance()
    n_x, n_t = np.shape(results["g_mem"])
    denominator = 2 * np.pi * results["c_mem"]
    if np.ndim(denominator) > 0:
        # per-position capacitance: make it a column so it divides every
        # time sample of the matching row
        denominator = np.reshape(denominator, (-1, 1))
    cutoff = np.zeros((n_x, n_t))
    cutoff[:, :] = results["g_mem"] / denominator
    # in [MHz] as g_mem in [S/cm^{2}] and c_mem [uF/cm^{2}]
    # * [MHz] to convert to [kHz]
    results["f_mem"] = cutoff * MHz
    return results["f_mem"]
[docs]
def get_myelin_properties(results):
    """
    Compute the myelin conductance, capacitance and cutoff frequency and add
    them to the simulation results dictionary.

    The keys ``"g_mye"``, ``"c_mye"`` and ``"f_mye"`` are added to ``results``.
    NB: The frequency is computed in [kHz]

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    results : dict
        results of axon.simulate method

    Returns
    -------
    f_mye : np.ndarray or None
        value of the cutoff frequency of the myelin, None when the axon is not
        myelinated or ``results["rec"]`` is ``"node"``
    """
    rise_warning(
        "DeprecationWarning: ",
        "get_myelin_properties is obsolete use method from axon_result objects instead",
    )
    if not results["myelinated"] or results["rec"] == "node":
        rise_warning("No myelin in the axon simulated, None returned")
        return None
    ax = generate_axon_from_results(results)
    results["g_mye"] = ax.get_myelin_conductance()
    results["c_mye"] = ax.get_myelin_capacitance()
    # in [MHz] as g_mye in [S/cm^{2}] and c_mye [uF/cm^{2}]
    # * [MHz] to convert to [kHz]
    # The scaling is applied to the freshly computed f_mye; the previous code
    # rescaled (and returned) the unrelated "f_mem" entry instead.
    results["f_mye"] = results["g_mye"] / (2 * np.pi * results["c_mye"]) * MHz
    return results["f_mye"]
#############################
## VISUALIZATION FUNCTIONS ##
#############################
[docs]
def plot_Nav_states(ax, values, title=""):
    """
    Plot the state machine for kinetic (Markov) Nav 1.1 to 1.9 values.

    Six states are drawn as discs whose area grows with the matching entry of
    ``values``, together with the arrows of the transitions between them.

    .. warning::
        Deprecated: a deprecation warning is emitted on every call.

    Parameters
    ----------
    ax : matplotlib axis object
        axes of the figure to work on
    values : list, array, numpy array
        one value per state, in the order I1, I2, C1, C2, O1, O2; each value
        scales the marker size of its state
    title : str, optional
        Title of the plot, by default ""
    """
    rise_warning(
        "DeprecationWarning: ",
        "plot_Nav_states is obsolete use method from axon_result objects instead",
    )
    states = [r"$I_1$", r"$I_2$", r"$C_1$", r"$C_2$", r"$O_1$", r"$O_2$"]
    X = [-1, -3, 0, 1, 0, 3]
    Y = [0, 0, 1, 0, -1, 0]
    c = ["r", "r", "b", "b", "g", "g"]
    ax.set_xlim(-3.4, 3.4)
    ax.set_ylim(-1.5, 1.5)
    for i, state_label in enumerate(states):
        ax.scatter(X[i], Y[i], s=300 + values[i] * 1450, c=c[i], alpha=0.4)
        ax.text(X[i], Y[i], state_label, ha="center", va="center")

    # paths: (arrow x, y, dx, dy), (label x, y), label
    transitions = (
        ((-2.5, 0.03, 1, 0), (-2, 0.2), "$I_2I_1$"),
        ((-1.5, -0.03, -1, 0), (-2, -0.2), "$I_1I_2$"),
        ((-0.83, 0.25, 0.5, 0.5), (-0.9, 0.6), "$I_1C_1$"),
        ((-0.22, 0.75, -0.5, -0.5), (-0.45, 0.25), "$C_1I_1$"),
        ((0.72, 0.25, -0.5, 0.5), (0.9, 0.6), "$C_1C_2$"),
        ((0.33, 0.75, 0.5, -0.5), (0.45, 0.25), "$C_2C_1$"),
        ((0.83, -0.25, -0.5, -0.5), (0.9, -0.6), "$C_2O_1$"),
        ((0.22, -0.75, 0.5, 0.5), (0.45, -0.25), "$O_1C_2$"),
        ((-0.33, -0.75, -0.5, 0.5), (-0.9, -0.6), "$O_1I_1$"),
        # label fixed: it used the digit 0 instead of the letter O
        ((1.5, 0.03, 1, 0), (2, 0.2), "$C_2O_2$"),
        ((2.5, -0.03, -1, 0), (2, -0.2), "$O_2C_2$"),
    )
    arrow_style = dict(linewidth=1, alpha=0.5, head_width=0.02, head_length=0.02)
    for arrow_geometry, label_xy, label in transitions:
        ax.arrow(*arrow_geometry, **arrow_style)
        ax.text(*label_xy, label, ha="center", va="center", alpha=0.4)

    # make axes to disappear
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    ax.axis("off")
###############################
## useful methods on results ##
###############################
[docs]
def default_PP(results: axon_results) -> axon_results:
    """
    Default postprocessing function: rasterize the data, then drop ``V_mem``
    to alleviate RAM usage.

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.

    Returns
    -------
    axon_results
        the same object, updated in place.
    """
    raw_voltage_key = "V_mem"
    # the raster is built first, the raw entry is removed afterwards
    results.rasterize()
    results.remove_key(raw_voltage_key)
    return results
[docs]
def rmv_keys(
    results: axon_results,
    keys_to_remove: str | set[str] = {},
    keys_to_keep: set[str] = {},
) -> axon_results:
    """
    Rasterize ``V_mem`` and remove most of the results keys to save memory.

    Note
    ----
    By default only the following keys are kept by the function

     - `ID`
     - `L`
     - `V_mem_raster_position`
     - `V_mem_raster_x_position`
     - `V_mem_raster_time_index`
     - `V_mem_raster_time`
     - `myelinated`
     - `intra_stim_starts`
     - `intra_stim_position`
     - `recorder`

    Warning
    -------
    This function removes a large amount of data from the `results` object.
    Some methods included in the results may not be usable afterwards.

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    keys_to_remove : str | set[str], optional
        key or collection of keys withdrawn from the list of kept keys, by
        default empty
    keys_to_keep : str | set[str], optional
        key or collection of keys added to the list of kept keys, by default
        empty

    Returns
    -------
    axon_results
        updated results.
    """

    def _as_key_set(keys):
        # a lone string is one key, not an iterable of characters
        if isinstance(keys, str):
            return {keys}
        return set(keys)

    results.rasterize("V_mem")
    kept_keys = {
        "ID",
        "L",
        "V_mem_raster_position",
        "V_mem_raster_x_position",
        "V_mem_raster_time_index",
        "V_mem_raster_time",
        "myelinated",
        "intra_stim_starts",
        "intra_stim_position",
        "recorder",
    }
    # the default arguments are only read, never mutated
    kept_keys |= _as_key_set(keys_to_keep)
    # removal wins over keeping when a key is listed in both
    kept_keys -= _as_key_set(keys_to_remove)
    results.remove_key(keys_to_keep=kept_keys)
    return results
[docs]
def is_recruited(results: axon_results) -> axon_results:
    """
    Run ``results.is_recruited`` on ``V_mem`` and remove most of the
    `axon_results` keys to alleviate RAM usage.

    Note
    ----
    Only the following keys are kept by the function

     - `ID`
     - `L`
     - `V_mem_raster_position`
     - `V_mem_raster_x_position`
     - `V_mem_raster_time_index`
     - `V_mem_raster_time`
     - `myelinated`
     - `y`
     - `z`
     - `diameter`
     - `intra_stim_starts`
     - `tstop`
     - `intra_stim_position`
     - `extracellular_electrode_x`
     - `recruited`

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.

    Returns
    -------
    axon_results
        updated results.
    """
    results.is_recruited(vm_key="V_mem")
    # everything not listed below is dropped
    raster_keys = {
        "V_mem_raster_position",
        "V_mem_raster_x_position",
        "V_mem_raster_time_index",
        "V_mem_raster_time",
    }
    metadata_keys = {
        "ID",
        "L",
        "myelinated",
        "y",
        "z",
        "diameter",
        "intra_stim_starts",
        "tstop",
        "intra_stim_position",
        "extracellular_electrode_x",
        "recruited",
    }
    results.remove_key(keys_to_keep=raster_keys | metadata_keys)
    return results
[docs]
def is_blocked(
    results: axon_results,
    AP_start: float | None = None,
    freq: float | None = None,
    t_refractory: float = 1,
) -> axon_results:
    """
    Run ``results.block_summary`` and remove most of the `axon_results` keys to
    alleviate RAM usage.

    Note
    ----
    Only the following keys are kept by the function, ``<vm>`` being ``V_mem``
    when ``freq`` is None and ``V_mem_filtered`` otherwise

     - `ID`
     - `L`
     - `<vm>_raster_position`
     - `<vm>_raster_x_position`
     - `<vm>_raster_time_index`
     - `<vm>_raster_time`
     - `myelinated`
     - `y`
     - `z`
     - `diameter`
     - `intra_stim_starts`
     - `tstop`
     - `intra_stim_position`
     - `extracellular_electrode_x`
     - `is_blocked`
     - `has_onset`
     - `n_onset`

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    AP_start : float | None, optional
        forwarded to ``results.block_summary``. If None and
        ``results["intra_stim_starts"]`` exists and differs from ``[]``, its
        first element is used, by default None
    freq : float | None, optional
        forwarded to ``results.block_summary``, by default None
    t_refractory : float, optional
        forwarded to ``results.block_summary``, by default 1

    Returns
    -------
    axon_results
        updated results.
    """
    use_first_stim_start = (
        AP_start is None
        and "intra_stim_starts" in results
        and results["intra_stim_starts"] != []
    )
    if use_first_stim_start:
        AP_start = results["intra_stim_starts"][0]
    vm_key = "V_mem" if freq is None else "V_mem_filtered"
    results.block_summary(AP_start=AP_start, freq=freq, t_refractory=t_refractory)
    # everything not listed below is dropped
    raster_keys = {
        f"{vm_key}_raster_{suffix}"
        for suffix in ("position", "x_position", "time_index", "time")
    }
    metadata_keys = {
        "ID",
        "L",
        "myelinated",
        "y",
        "z",
        "diameter",
        "intra_stim_starts",
        "tstop",
        "intra_stim_position",
        "extracellular_electrode_x",
        "is_blocked",
        "has_onset",
        "n_onset",
    }
    results.remove_key(keys_to_keep=raster_keys | metadata_keys)
    return results
[docs]
def sample_keys(
    results: axon_results,
    keys_to_sample: str | set[str] = {},
    t_start_rec: float = 0,
    t_stop_rec: float = -1,
    sample_dt: None | float = None,
    i_sampled_t: None | np.ndarray = None,
    x_bounds: None | float | tuple[float] = None,
    keys_to_remove: str | set[str] = {},
    keys_to_keep: set[str] = {},
) -> axon_results:
    """
    Undersample the requested keys in space and time and remove most of the
    `axon_results` keys to alleviate RAM usage.

    ``results["x_rec"]`` (shifted by the lower x bound), ``results["t"]`` when
    present and every key of ``keys_to_sample`` are replaced by their sampled
    version.

    Note
    ----
    Only the following keys are kept by the function, in addition to
    ``keys_to_sample`` and ``keys_to_keep``

     - `ID`
     - `model`
     - `x_rec`
     - `rec`
     - `Nseg_per_sec`
     - `axon_path_type`
     - `t_sim`
     - `myelinated`
     - `intra_stim_starts`
     - `intra_stim_position`
     - `recorder`
     - `t` (only when ``results.ID == 0``)

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    keys_to_sample : str | set[str], optional
        key or keys to undersample; each entry is indexed as
        ``results[key][positions, times]``, by default empty
    t_start_rec : float, optional
        Lower time from which samples are kept, by default 0
    t_stop_rec : float, optional
        Upper time for the kept samples; a negative value keeps samples up to
        the end of ``results["t"]``, by default -1
    sample_dt : None | float, optional
        Time sample rate of the kept samples, if None simulation dt is kept,
        by default None
    i_sampled_t : None | np.ndarray, optional
        explicit time indexes to keep; when given, ``t_start_rec``,
        ``t_stop_rec`` and ``sample_dt`` are ignored, by default None
    x_bounds : None | float | tuple[float], optional
        x-positions to keep, possible options:

         - float: only the nearest position in `x_rec` is kept.
         - tuple: all positions strictly between the two boundaries are kept.
         - None (default): all positions are kept.
    keys_to_remove : str | set[str], optional
        forwarded to ``results.remove_key``, by default empty
    keys_to_keep : set[str], optional
        additional keys to keep, by default empty

    Warning
    -------
    ``sample_dt`` should be a multiple of the simulation ``dt`` to allow a
    correct undersampling. If not, the closest multiple of ``dt`` (at least
    ``dt``) is used.

    Returns
    -------
    axon_results
        updated results.
    """
    # the mutable defaults are only read, never mutated
    if isinstance(keys_to_sample, str):
        keys_to_sample = {keys_to_sample}
    if isinstance(keys_to_keep, str):
        # a lone string is one key, not an iterable of characters
        keys_to_keep = {keys_to_keep}
    missing_keys = set(keys_to_sample) - set(results.keys())
    if missing_keys:
        rise_error(
            missing_keys,
            "keys are missing to apply postprocessing. Please check simulation parameters",
        )
        return results

    # x - sampling array
    if x_bounds is None:
        I_x = np.arange(len(results["x_rec"]))
        x_bounds = (0, results["x_rec"][-1])
    elif np.iterable(x_bounds):
        I_x = np.argwhere(
            (results["x_rec"] > x_bounds[0]) & (results["x_rec"] < x_bounds[1])
        )[:, 0]
    else:
        x_bounds = [x_bounds]
        I_x = np.array([np.argmin(abs(results["x_rec"] - x_bounds[0]))])

    # t - sampling array
    if i_sampled_t is not None:
        t_APs = i_sampled_t
    else:
        if t_stop_rec < 0:
            i_t_max = len(results["t"])
        else:
            # used as an exclusive stop below, so this last matching sample
            # itself is not kept
            i_t_max = np.argwhere(results["t"] <= t_stop_rec)[-1][0]
        if sample_dt is None:
            sample_dt = results.dt
        i_t_min = np.argwhere(results["t"] >= t_start_rec)[0][0]
        # round instead of truncating (0.03 / 0.01 evaluates to 2.99...) and
        # never let the stride fall to 0 when sample_dt < dt
        stride = max(1, int(round(sample_dt / results.dt)))
        t_APs = list(range(i_t_min, i_t_max, stride))

    # Under sampling to reduce memory consumption
    results["x_rec"] = results["x_rec"][I_x] - x_bounds[0]
    if "t" in results:
        results["t"] = results["t"][t_APs]
    for key in keys_to_sample:
        results[key] = results[key][np.ix_(I_x, t_APs)]

    # remove non necessary data
    list_keys = {
        "ID",
        "model",
        "x_rec",
        "rec",
        "Nseg_per_sec",
        "axon_path_type",
        "t_sim",
        "myelinated",
        "intra_stim_starts",
        "intra_stim_position",
        "recorder",
    }
    list_keys.update(keys_to_keep)
    list_keys.update(keys_to_sample)
    if results.ID == 0:
        list_keys.update({"t"})
    results.remove_key(keys_to_keep=list_keys, keys_to_remove=keys_to_remove)
    return results
[docs]
def sample_g_mem(
    results: axon_results,
    t_start_rec: float = 0,
    t_stop_rec: float = -1,
    sample_dt: None | float = None,
    x_bounds: None | float | tuple[float] = None,
) -> axon_results:
    """
    Undersample the membrane conductivity (``results["g_mem"]``) and remove most of the other `axon_results` keys to alleviate RAM usage.

    Note
    ----
    Only the following keys are kept by the function

        - `g_mem`
        - `x_rec`
        - `rec`
        - `Nseg_per_sec`
        - `axon_path_type`
        - `t_sim`
        - `t` (only when ``results.ID == 0``)

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    t_start_rec : float, optional
        Lower time at which `g_mem` should be stored (samples with ``t >= t_start_rec`` are candidates), by default 0
    t_stop_rec : float, optional
        Upper time at which `g_mem` should be stored. A negative value keeps the record up to its last sample.
        Otherwise the window stops just before the last sample such that ``t <= t_stop_rec``, by default -1
    sample_dt : None | float, optional
        Time sample rate at which `g_mem` should be stored if None simulation dt is kept, by default None
    x_bounds : None | float | tuple[float], optional
        x-positions where to store `g_mem`, possible option:

            - float: The values of `g_mem` are only stored at the nearest position in `x_rec`.
            - tuple: `g_mem` values are stored for all positions strictly included between the two boundaries.
            - None (default): `g_mem` values are stored for all positions.

        The stored ``x_rec`` is shifted by the lower boundary (or by the float value); it is left unshifted when `x_bounds` is None.

    Warning
    -------
    ``sample_dt`` should be a multiple of the simulation ``dt`` to allow a correct undersampling.
    If it is not, the closest multiple of ``dt`` (and at least ``dt`` itself) is used.

    Returns
    -------
    axon_results
        updated results (the same object, modified in place).
    """
    if not results["record_g_mem"]:
        rise_error("gmem not recorded nothing will be done")
        # Nothing to sample: hand the results back untouched.
        return results

    if sample_dt is None:
        sample_dt = results.dt

    # x - sampling indexes
    x_rec = results["x_rec"]
    if x_bounds is None:
        # Keep every recorded position, end points included.
        i_x = np.arange(len(x_rec))
        x_offset = 0
    elif np.iterable(x_bounds):
        i_x = np.argwhere((x_rec > x_bounds[0]) & (x_rec < x_bounds[1]))[:, 0]
        x_offset = x_bounds[0]
    else:
        # Single position requested: keep only the closest recording point.
        i_x = np.array([np.argmin(abs(x_rec - x_bounds))])
        x_offset = x_bounds

    # t - sampling indexes
    t = results["t"]
    i_t_min = np.argwhere(t >= t_start_rec)[0][0]
    if t_stop_rec < 0:
        i_t_max = len(t)
    else:
        i_t_max = np.argwhere(t <= t_stop_rec)[-1][0]
    # Round rather than truncate: float ratios such as 0.03 / 0.01 evaluate
    # just below the intended integer. Never go below one sample per step,
    # which would otherwise produce an invalid zero stride.
    step = max(1, int(round(sample_dt / results.dt)))
    i_t = np.arange(i_t_min, i_t_max, step)

    # Under sampling to reduce memory consumption
    results["x_rec"] = x_rec[i_x] - x_offset
    results["t"] = t[i_t]
    results["g_mem"] = results["g_mem"][np.ix_(i_x, i_t)]

    # Remove non necessary data
    kept_keys = {
        "g_mem",
        "x_rec",
        "rec",
        "Nseg_per_sec",
        "axon_path_type",
        "t_sim",
    }
    if results.ID == 0:
        kept_keys.add("t")
    results.remove_key(keys_to_keep=kept_keys)
    return results
[docs]
def vmem_plot(
    results: axon_results,
    freq: float | None = None,
    save: bool = False,
    fdir: str = "",
) -> axon_results:
    """
    Plot the membrane potential of an axon along position and time, and optionally save the figure.

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    freq : float, optional
        If not None, ``results.filter_freq("V_mem", freq, Q=2)`` is called first and the
        ``"V_mem_filtered"`` key is plotted instead of ``"V_mem"``, by default None
    save : bool, optional
        if true, the figure is saved as a `.png` file and then closed, by default False
    fdir : str, optional
        Directory where the figure should be saved, the saving file will be
        `fdir/Activity_axon_<ID>.png`. An empty string saves in the current working directory, by default ""

    Returns
    -------
    axon_results
        the results object passed as input.
    """
    # Local import: only needed to build the output path.
    import os

    vm_key = "V_mem"
    if freq is not None:
        results.filter_freq("V_mem", freq, Q=2)
        vm_key += "_filtered"

    fig, ax = plt.subplots()
    results.plot_x_t(ax, key=vm_key)
    fiber_kind = "Myelinated" if results.myelinated else "Unmyelinated"
    ax.set_title(f"{fiber_kind} Axon: {np.round(results.diameter, 2)} µm in diameter")

    if save:
        fig.tight_layout()
        # os.path.join keeps an empty ``fdir`` relative to the working
        # directory instead of resolving to the filesystem root.
        fig.savefig(os.path.join(fdir, f"Activity_axon_{results.ID}.png"))
        # The figure is on disk; release it to limit memory use.
        plt.close(fig)
    return results
[docs]
def raster_plot(
    results: axon_results,
    freq: float | None = None,
    save: bool = False,
    fdir: str = "",
) -> axon_results:
    """
    Draw the raster plot of an axon and optionally save the figure.

    Parameters
    ----------
    results : axon_results
        results of the axon simulation.
    freq : float, optional
        If not None, ``results.filter_freq("V_mem", freq, Q=2)`` is called first and the
        ``"V_mem_filtered"`` key is used instead of ``"V_mem"``, by default None
    save : bool, optional
        if true, the figure is saved as a `.png` file and then closed, by default False
    fdir : str, optional
        Directory where the figure should be saved, the saving file will be
        `fdir/Rasterplot_axon_<ID>.png`. An empty string saves in the current working directory, by default ""

    Returns
    -------
    axon_results
        the results object passed as input.
    """
    # Local import: only needed to build the output path.
    import os

    vm_key = "V_mem"
    if freq is not None:
        results.filter_freq("V_mem", freq, Q=2)
        vm_key += "_filtered"

    fig, ax = plt.subplots()
    results.raster_plot(ax, key=vm_key)
    fiber_kind = "Myelinated" if results.myelinated else "Unmyelinated"
    ax.set_title(f"{fiber_kind} Axon: {np.round(results.diameter, 2)} µm in diameter")

    if save:
        fig.tight_layout()
        # os.path.join keeps an empty ``fdir`` relative to the working
        # directory instead of resolving to the filesystem root.
        fig.savefig(os.path.join(fdir, f"Rasterplot_axon_{results.ID}.png"))
        # The figure is on disk; release it to limit memory use.
        plt.close(fig)
    return results