Source code for nrv.optim.optim_utils._CostEvaluation
import numpy as np
from ...utils._stimulus import stimulus, set_common_time_series
from ...fmod._extracellular import extracellular_context
from ...backend._NRV_Class import NRV_class, load_any, abstractmethod
from ...backend._NRV_Simulable import sim_results
from ...nmod.results._axons_results import axon_results
from ...nmod.results._fascicles_results import fascicle_results
from ...nmod.results._nerve_results import nerve_results
from ...utils._nrv_function import cost_evaluation
from ...ui._axon_postprocessing import *
[docs]
class raster_count_CE(cost_evaluation):
    """
    Callable cost evaluation returning the number of spikes found in the
    result of a simulation.

    The value is the average of the number of raster events located at the
    first recorded position and at the last recorded position.
    """

    def __init__(self):
        super().__init__()

    def call_method(self, results: sim_results, **kwargs) -> float:
        """
        Return the spike count extracted from a simulation result.

        Parameters
        ----------
        results : sim_results
            simulation result. ``results.rasterize("V_mem")`` is called first
            when the key ``"V_mem_raster_position"`` is not already present.
        **kwargs
            accepted for interface compatibility, not used here.

        Returns
        -------
        float
            half of (number of raster events at position index 0 + number of
            raster events at the last position index).
        """
        if not ("V_mem_raster_position" in results):
            results.rasterize("V_mem")
        raster_positions = results["V_mem_raster_position"]
        # Position indexes are 0-based, so the last one is len(x_rec) - 1
        last_index = len(results["x_rec"]) - 1
        n_at_start = len(np.where(raster_positions == 0)[0])
        n_at_end = len(np.where(raster_positions == last_index)[0])
        return (n_at_start + n_at_end) / 2
[docs]
class recrutement_count_CE(cost_evaluation):
    r"""
    Callable object which returns the number of triggered fibres in the results

    Parameters
    ----------
    reverse : bool
        if True, the final cost is the difference between the total number of
        fibres and the number of activated fibres

    Note
    ----
    if reverse is false:

    .. math::

        cost = N_{recruited}

    else:

    .. math::

        cost = N_{total} - N_{recruited}
    """

    def __init__(self, reverse=False):
        super().__init__()
        self.reverse = reverse

    def count_axon_activation(self, results: sim_results):
        """
        Return 0 or 1 for a single-axon result.

        The value is ``int(results.is_recruited())``, logically negated
        beforehand when ``self.reverse`` is set.
        """
        recruited = results.is_recruited()
        return int(not recruited) if self.reverse else int(recruited)

    def count_fascicle_activation(self, results: sim_results):
        """
        Sum the ``"spike"`` entry of every ``"axon<i>"`` item of ``results``.

        One item is read per element of ``results["axons_diameter"]``. When
        ``self.reverse`` is set, ``1 - spike`` is accumulated instead.
        """
        total = 0
        for index in range(len(results["axons_diameter"])):
            spike = results["axon" + str(index)]["spike"]
            total += (1 - spike) if self.reverse else spike
        return total

    def call_method(self, results: sim_results, **kwargs) -> float:
        """
        Return the recruitment cost of a simulation result.

        Parameters
        ----------
        results : sim_results
            an ``axon_results`` instance is handled by
            ``count_axon_activation``; any other result must provide
            ``get_recruited_axons`` and ``n_ax``.
        **kwargs
            accepted for interface compatibility, not used here.

        Returns
        -------
        cost
            for an ``axon_results``: the value of ``count_axon_activation``.
            Otherwise ``results.get_recruited_axons(ax_type="all",
            normalize=False)``, subtracted from ``results.n_ax`` when
            ``self.reverse`` is set.
        """
        # Single axon: the reverse option is handled by the helper itself
        if isinstance(results, axon_results):
            return self.count_axon_activation(results)

        n_recruited = results.get_recruited_axons(ax_type="all", normalize=False)
        if self.reverse:
            return results.n_ax - n_recruited
        return n_recruited
[docs]
class charge_quantity_CE(cost_evaluation):
    r"""
    Create a callable object which returns a value proportional to the charge
    quantity injected by the stimulus.

    .. math::

        cost = \sum_{e}\sum_{t_k}{\left| i_{e,stim}(t_k) \right|}

    with :math:`t_k` the discrete time steps of the resampled stimulus

    Parameters
    ----------
    id_elec : None | int | list[int]
        id or list of ids of the electrodes from which the charge should be
        computed. If None, all the electrodes of the result are used.
    dt_res : float
        resolution time step used to compute the cost value
    """

    def __init__(self, id_elec: None | int | list[int] = None, dt_res: float = 0.0001):
        super().__init__()
        self.id_elec = id_elec
        self.dt_res = dt_res

    def compute_stimulus_cost(self, stim: stimulus):
        """
        Return the integral of the absolute value of one stimulus.

        Parameters
        ----------
        stim : stimulus
            stimulus of one electrode. It is passed to
            ``set_common_time_series`` together with a uniformly sampled
            empty stimulus before being integrated.

        Returns
        -------
        value returned by ``abs(stim).integrate()``
        """
        time_grid = stimulus()
        t_min, t_max = stim.t[0], stim.t[-1]
        # NOTE(review): n_pts is 0 when the stimulus lasts less than dt_res,
        # which gives an empty time grid -- confirm this cannot happen.
        n_pts = int((t_max - t_min) // self.dt_res)
        time_grid.t = np.linspace(t_min, t_max, n_pts)
        # time_grid is not read afterwards: the call is expected to update
        # stim in place -- TODO confirm against set_common_time_series.
        set_common_time_series(stim, time_grid)
        return abs(stim).integrate()

    def _selected_electrodes(self, n_elec: int):
        """Return the electrode ids to use for a result with n_elec electrodes."""
        if self.id_elec is None:
            return range(n_elec)
        if isinstance(self.id_elec, (int, np.integer)):
            return [self.id_elec]
        return self.id_elec

    def call_method(self, results: sim_results, **kwargs) -> float:
        """
        Return the summed stimulus cost of the selected electrodes.

        Parameters
        ----------
        results : sim_results
            simulation result; its ``"extra_stim"`` entry is loaded with
            ``load_any`` and must expose a ``stimuli`` sequence.
        **kwargs
            accepted for interface compatibility, not used here.

        Returns
        -------
        cost
            sum of ``compute_stimulus_cost`` over the selected electrodes.
        """
        extra_stim = load_any(results["extra_stim"])
        # The selection is resolved on every call and kept local, so that
        # id_elec=None keeps meaning "all electrodes" for results that do not
        # have the same number of electrodes.
        electrode_ids = self._selected_electrodes(len(extra_stim.stimuli))
        return sum(
            self.compute_stimulus_cost(extra_stim.stimuli[i]) for i in electrode_ids
        )
[docs]
class stim_energy_CE(cost_evaluation):
    r"""
    Create a callable object which returns a value proportional to the
    stimulus energy, assuming the electrode impedance is a constant.

    .. math::

        cost = \sum_{e}\sum_{t_k}{i_{e,stim}^2(t_k)}

    with :math:`t_k` the discrete time steps of the resampled stimulus

    Parameters
    ----------
    id_elec : None | int | list[int]
        id or list of ids of the electrodes from which the energy should be
        computed. If None, all the electrodes of the result are used.
    dt_res : float
        resolution time step used to compute the cost value
    """

    def __init__(self, id_elec: None | int | list[int] = None, dt_res: float = 0.0001):
        super().__init__()
        self.id_elec = id_elec
        self.dt_res = dt_res

    def compute_stimulus_cost(self, stim: stimulus):
        """
        Return the integral of the squared value of one stimulus.

        Parameters
        ----------
        stim : stimulus
            stimulus of one electrode. It is passed to
            ``set_common_time_series`` together with a uniformly sampled
            empty stimulus, then its ``s`` attribute is overwritten with its
            square (the object given by the caller is modified).

        Returns
        -------
        value returned by ``abs(stim).integrate()`` once ``stim.s`` is squared
        """
        time_grid = stimulus()
        t_min, t_max = stim.t[0], stim.t[-1]
        # NOTE(review): n_pts is 0 when the stimulus lasts less than dt_res,
        # which gives an empty time grid -- confirm this cannot happen.
        n_pts = int((t_max - t_min) // self.dt_res)
        time_grid.t = np.linspace(t_min, t_max, n_pts)
        # time_grid is not read afterwards: the call is expected to update
        # stim in place -- TODO confirm against set_common_time_series.
        set_common_time_series(stim, time_grid)
        stim.s = stim.s * stim.s
        return abs(stim).integrate()

    def _selected_electrodes(self, n_elec: int):
        """Return the electrode ids to use for a result with n_elec electrodes."""
        if self.id_elec is None:
            return range(n_elec)
        if isinstance(self.id_elec, (int, np.integer)):
            return [self.id_elec]
        return self.id_elec

    def call_method(self, results: sim_results, **kwargs) -> float:
        """
        Return the summed stimulus cost of the selected electrodes.

        Parameters
        ----------
        results : sim_results
            simulation result; its ``"extra_stim"`` entry is loaded with
            ``load_any`` and must expose a ``stimuli`` sequence.
        **kwargs
            accepted for interface compatibility, not used here.

        Returns
        -------
        cost
            sum of ``compute_stimulus_cost`` over the selected electrodes.
        """
        extra_stim = load_any(results["extra_stim"])
        # The selection is resolved on every call and kept local, so that
        # id_elec=None keeps meaning "all electrodes" for results that do not
        # have the same number of electrodes.
        electrode_ids = self._selected_electrodes(len(extra_stim.stimuli))
        return sum(
            self.compute_stimulus_cost(extra_stim.stimuli[i]) for i in electrode_ids
        )