"""
NRV-:class:`.NRV_results` handling.
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
from collections.abc import Iterable
from scipy import signal
from ._NRV_Class import NRV_class, load_any, abstractmethod, is_NRV_class
from ._log_interface import rise_warning, pass_info
from ..utils._stimulus import stimulus
from ._file_handler import json_load
def generate_results(obj: any, **kwargs):
    """
    Generate the proper results object depending on the simulated object.

    ``obj`` is loaded with ``load_any`` and the ``nrv_type`` attribute of the
    loaded object selects the results class: the attribute named
    ``<nrv_type>_results`` is looked up on the ``nrv`` package and called
    with ``context=obj``.

    Parameters
    ----------
    obj : any
        Object handed to ``load_any``; it is also forwarded, unchanged, as
        the ``context`` of the created results object.
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    object
        Instance created by the selected ``*_results`` class.

    Raises
    ------
    AttributeError
        If the loaded object has no ``nrv_type`` instance attribute, or if
        the ``nrv`` package exposes no matching ``*_results`` attribute.
    """
    nrv_obj = load_any(obj)
    if "nrv_type" not in nrv_obj.__dict__:
        # Fail with an explicit message instead of an unbound local variable.
        raise AttributeError(
            "cannot generate results: loaded object has no 'nrv_type' attribute"
        )
    # Plain attribute lookup on the package: same resolution as evaluating
    # 'sys.modules["nrv"].<name>', without executing a built string.
    results_class = getattr(sys.modules["nrv"], nrv_obj.nrv_type + "_results")
    return results_class(context=obj)
class NRV_results(NRV_class, dict):
    """
    Results class for NRV.

    The object is both an ``NRV_class`` and a ``dict``: results are stored as
    mapping items and mirrored in the instance ``__dict__`` so that they can
    be read either as ``res["key"]`` or as ``res.key``. The overridden mapping
    methods write to both places, and a resynchronization is run at the end
    of ``__init__``, ``save`` and ``load``.
    """

    @abstractmethod
    def __init__(self, context=None):
        """
        Initialize a results container from a saved context.

        Parameters
        ----------
        context : dict | NRV_class | None, optional
            Initial context used to populate the results object. If an
            :class:`~nrv.backend._NRV_Class.NRV_class` instance is provided,
            the dictionary returned by its ``save(save=False)`` is used.
            A ``nrv_type`` entry of the context is stored under the key
            ``result_type``. A dictionary passed by the caller is copied
            and left unmodified.
        """
        super().__init__()
        # Not ideal, but required by the load method
        self.np_keys = {}
        if context is None:
            context = {}
        elif is_NRV_class(context):
            # Use the serialized state; the returned dictionary must be kept,
            # the NRV object itself cannot be used as a mapping below.
            context = context.save(save=False)
        else:
            # Shallow copy so that renaming 'nrv_type' below does not alter
            # the mapping owned by the caller.
            context = dict(context)
        if "nrv_type" in context:
            context["result_type"] = context.pop("nrv_type")
        # NOTE(review): the original comment at this point read "Discard
        # saving for empty results (mostly for Mcore)"; presumably it relates
        # to the 'dummy_res' marker tested by ``to_save`` -- confirm.
        self.update(context)
        self.__sync()

    @property
    def to_save(self):
        """
        Indicate whether the results object should be serialized.

        Returns
        -------
        bool
            ``True`` if the ``dummy_res`` key is absent from the results.
        """
        return "dummy_res" not in self

    @property
    def is_dummy(self):
        """
        Indicate whether the results object is a dummy placeholder.

        Returns
        -------
        bool
            ``True`` if the ``dummy_res`` key is present in the results.
        """
        return "dummy_res" in self

    def save(self, save=False, fname="nrv_save.json", blacklist=None, **kwargs):
        """
        Save the results object after synchronizing its internal state.

        Parameters
        ----------
        save : bool, optional
            If ``True``, write the serialized content to disk. The flag is
            ignored (forced to ``False``) when ``to_save`` is ``False``.
        fname : str, optional
            Output filename used when ``save`` is ``True``.
        blacklist : list, optional
            Keys excluded from serialization, by default no key is excluded.
        **kwargs
            Forwarded to :meth:`NRV_class.save`.

        Returns
        -------
        dict
            Serialized representation returned by :meth:`NRV_class.save`.
        """
        # A fresh list per call avoids sharing one default list between calls.
        blacklist = [] if blacklist is None else blacklist
        save = save and self.to_save
        # Record which entries are numpy arrays (and their dtype) so that
        # ``load`` can rebuild them, then align items and attributes.
        self.__update_np_keys()
        self.__sync()
        return super().save(save, fname, blacklist, **kwargs)

    def load(self, data, blacklist=None, **kwargs):
        """
        Load the results object from serialized data.

        Before delegating to :meth:`NRV_class.load`, every key of the
        serialized data is declared as an instance attribute: keys listed in
        the ``np_keys`` entry are declared as empty numpy arrays of the
        recorded dtype, other keys not yet present are set to their
        serialized value.

        Parameters
        ----------
        data : str | dict
            JSON filename or serialized dictionary.
        blacklist : list, optional
            Keys excluded from loading, by default no key is excluded.
        **kwargs
            Forwarded to :meth:`NRV_class.load`.
        """
        blacklist = [] if blacklist is None else blacklist
        key_dic = json_load(data) if isinstance(data, str) else data
        # Tolerate serialized data without a 'np_keys' entry.
        np_keys = key_dic.get("np_keys") or {}
        for key, item in key_dic.items():
            if key in np_keys:
                self.__dict__[key] = np.array([], dtype=np.dtype(np_keys[key]))
            elif key not in self.__dict__:
                self.__dict__[key] = item
        super().load(data, blacklist, **kwargs)
        self.__sync()

    def __setitem__(self, key, value):
        """
        Set a result value in both the mapping and the attribute namespace.

        The key ``nrv_type`` is stored in the mapping only; it is not
        mirrored as an instance attribute.

        Parameters
        ----------
        key : str
            Result key to set.
        value : any
            Value associated with the key.
        """
        if key != "nrv_type":
            self.__dict__[key] = value
        super().__setitem__(key, value)

    def __delitem__(self, key):
        """
        Delete a result value from both the mapping and the attribute namespace.

        If the key is not an instance attribute a warning is issued and
        nothing is deleted.

        Parameters
        ----------
        key : str
            Result key to delete.
        """
        if key not in self.__dict__:
            rise_warning(key, "not found cannot be deleted from results")
            return
        # The 'nrv_type' attribute is kept, only the mapping item is removed.
        if key != "nrv_type":
            del self.__dict__[key]
        super().__delitem__(key)

    def remove_key(
        self,
        keys_to_remove: str | Iterable[str] = (),
        keys_to_keep: str | Iterable[str] | None = None,
        verbose: bool = False,
    ):
        """
        Remove a key or a collection of keys from the results.

        Parameters
        ----------
        keys_to_remove : str | Iterable[str], optional
            Key or collection of keys that should be removed, by default
            nothing is removed. Ignored when ``keys_to_keep`` is given.
        keys_to_keep : str | Iterable[str] | None, optional
            If ``None`` only ``keys_to_remove`` are removed. Otherwise, all
            keys except those listed here are deleted, by default ``None``.
        verbose : bool, optional
            Accepted for compatibility; the information message it used to
            control is currently disabled, so the flag has no effect.
        """
        if keys_to_keep is not None:
            # A lone string is one key, not a collection of characters.
            if isinstance(keys_to_keep, str):
                keys_to_keep = [keys_to_keep]
            keys_to_remove = set(self.keys()) - set(keys_to_keep)
        elif isinstance(keys_to_remove, str):
            keys_to_remove = [keys_to_remove]
        else:
            # Snapshot: the argument may be a live view of this mapping,
            # which must not be iterated while items are deleted.
            keys_to_remove = list(keys_to_remove)
        for key in keys_to_remove:
            del self[key]

    def update(self, __m=(), **kwargs) -> None:
        """
        Overload of the dict ``update`` method, updating both the instance
        attributes and the mapping items.

        Parameters
        ----------
        __m : mapping | iterable of pairs, optional
            Entries to add, as accepted by ``dict.update``.
        **kwargs
            Additional entries given as keywords.
        """
        # Materialize once: a one-shot iterable of pairs would otherwise be
        # exhausted by the first of the two updates.
        entries = dict(__m, **kwargs)
        self.__dict__.update(entries)
        super().update(entries)

    def __update_np_keys(self):
        """
        Rebuild ``np_keys``: the dtype name of every item that is a numpy
        array, indexed by its key.
        """
        self.np_keys = {
            key: value.dtype.name
            for key, value in self.items()
            if isinstance(value, np.ndarray)
        }

    @property
    def is_empty(self):
        """
        Flag derived from the ``result_type`` item.

        NOTE(review): despite its name, this property is ``True`` when a
        ``result_type`` is declared and is not ``None``; the behavior is
        kept unchanged -- confirm the intended meaning against the callers.

        Returns
        -------
        bool
            ``True`` if ``result_type`` is present and not ``None``.
        """
        return "result_type" in self and self["result_type"] is not None

    def __sync(self):
        """
        Synchronize the dictionary content with the instance attributes.

        Every instance attribute is copied to the mapping, then the
        ``__NRVObject__`` item is dropped from the mapping if present.
        """
        self.update(self.__dict__)
        self.pop("__NRVObject__", None)

    def __contains__(self, key: object) -> bool:
        """
        Check whether one key or a set of keys is present in the results.

        Parameters
        ----------
        key : object
            Key to test, or ``list``/``set`` of keys.

        Returns
        -------
        bool
            ``True`` if the requested key is present; for a ``list`` or a
            ``set``, ``True`` if none of the keys is missing.
        """
        if isinstance(key, (list, set)):
            return set(key).issubset(self.keys())
        return super().__contains__(key)
class sim_results(NRV_results):
    """
    Results container specialized for simulation outputs.
    """

    def __init__(self, context=None):
        """
        Initialize simulation results from a serialized context.

        Parameters
        ----------
        context : dict | NRV_class | None, optional
            Initial simulation context used to populate the results object.
        """
        super().__init__(context)

    def filter_freq(self, my_key, freq, Q=10):
        """
        Basic filtering of quantities. This function designs notch filters
        (scipy IIR-notch) and applies them row by row to the selected item.
        The result is stored in the results under the original key with the
        termination '_filtered' concatenated.

        Parameters
        ----------
        my_key : str
            name of the key to filter; the item is read as an array whose
            rows (first index) are filtered independently
        freq : float or array, list, np.array
            frequency or list of frequencies to filter in kHz, as time is
            defined in ms in NRV2. If multiple frequencies, they are filtered
            sequentially, with as many filters as frequencies, in the
            specified order
        Q : float
            quality factor of the filter, by default set to 10

        Returns
        -------
        bool | None
            ``False`` if the filtering is aborted because ``dt`` is 0,
            ``None`` otherwise.
        """
        if self["dt"] == 0:
            rise_warning(
                "Warning: filtering aborted, variable time step used for differential equation solving"
            )
            return False
        fs = 1 / self["dt"]
        # One notch per requested frequency, designed once for all rows;
        # a scalar frequency is handled as a one-element sequence.
        notches = [signal.iirnotch(f, Q, fs) for f in np.atleast_1d(freq)]
        raw = self[my_key]
        new_sig = np.zeros(raw.shape)
        for k, row in enumerate(raw):
            # Filter around the initial value so the filter starts from a
            # zero-valued signal; the offset is restored afterwards.
            offset = row[0]
            centered = row - offset
            for b_notch, a_notch in notches:
                # The whole row is filtered by each notch in turn.
                centered = signal.lfilter(b_notch, a_notch, centered)
            new_sig[k, :] = centered + offset
        self[my_key + "_filtered"] = new_sig

    def plot_stim(self, IDs=None, t_stop=None, N_pts=1000, ax=None, **fig_kwargs):
        """
        Plot one or several stimuli of the simulation extra-cellular context.

        Parameters
        ----------
        IDs : int | iterable of int | None, optional
            Index or indices of the stimuli to plot in
            ``self["extra_stim"]["stimuli"]``, by default all of them.
        t_stop : float | None, optional
            End of the time axis. If ``None``, the last time value of the
            first plotted stimulus is used, for this stimulus and for all
            the following ones.
        N_pts : int, optional
            Number of points of the zero-valued support stimulus to which
            each stimulus is added before plotting, by default 1000.
        ax : object | None, optional
            Object whose ``plot`` method is called; if ``None``,
            ``matplotlib.pyplot.plot`` is used.
        **fig_kwargs
            Forwarded to the ``plot`` call.
        """
        if "extra_stim" not in self:
            rise_warning("No extracellular stimulation to be plotted")
            return
        stimuli = self["extra_stim"]["stimuli"]
        if IDs is None:
            IDs = list(range(len(stimuli)))
        elif not np.iterable(IDs):
            IDs = [int(IDs)]
        drawer = plt if ax is None else ax
        for i in IDs:
            stim = load_any(stimuli[i])
            if t_stop is None:
                t_stop = stim.t[-1]
            # Zero-valued stimulus on a regular time grid, to which the
            # loaded stimulus is added.
            stim2 = stimulus()
            stim2.s = np.zeros(N_pts)
            stim2.t = np.linspace(0, t_stop, N_pts)
            stim2 += stim
            drawer.plot(stim2.t, stim2.s, **fig_kwargs)