Source code for nrv.optim.optim_utils._OptimFunctions

import numpy as np
import matplotlib.pyplot as plt
import csv

from ...backend._file_handler import rmv_ext
from ...backend._log_interface import rise_warning
from ...utils._nrv_function import nrv_interp


####################################################################
################# generate waveform functions ######################
####################################################################


## interpolation
def interpolate(
    y: np.ndarray,
    x: np.ndarray = None,
    scale=4,
    intertype="Spline",
    bounds=(0, 0),
    save=False,
    filename="interpolate.dat",
    save_scale=False,
    kwargs_interp=None,
):
    """
    Resample ``y`` (sampled at ``x``) on a new abscissa with ``nrv_interp``.

    :meta private:

    Parameters
    ----------
    y : np.ndarray
        Values to interpolate.
    x : np.ndarray, optional
        Abscissa of ``y``. If None or empty, ``1, 2, ..., len(y)`` is used,
        by default None.
    scale : int | float | array-like, optional
        If a (python or numpy) scalar: number of evenly spaced points between
        ``x[0]`` and ``x[-1]``. If a non-empty sequence: the new abscissa
        itself. By default 4.
    intertype : str, optional
        Interpolation kind forwarded (lower-cased) to ``nrv_interp``;
        "Spline" is mapped to "catmull-rom", by default "Spline".
    bounds : tuple, optional
        Clipping range applied to the result; when both values are equal no
        clipping is done, by default (0, 0).
    save : bool, optional
        If True, write the result to ``filename``, by default False.
    filename : str, optional
        Output file name, by default "interpolate.dat".
    save_scale : bool, optional
        If True and ``filename`` ends with "csv", write the (x_scale, y_scale)
        pairs as comma separated columns instead of the values only,
        by default False.
    kwargs_interp : dict, optional
        Extra keyword arguments for ``nrv_interp``, by default None (no extra
        arguments).

    Returns
    -------
    array-like
        Interpolated values evaluated on the new abscissa (clipped to
        ``bounds`` when they differ).
    """
    if kwargs_interp is None:
        kwargs_interp = {}
    if x is None or len(x) == 0:
        # Default abscissa: 1-based sample indices.
        x = [i + 1 for i in range(len(y))]

    # Accept numpy scalars too (np.int64 is not a python int).
    if isinstance(scale, (int, float, np.integer, np.floating)):
        x_scale = np.linspace(x[0], x[-1], int(scale))
    elif len(scale) >= 1:
        x_scale = scale
    else:
        # np.linspace requires an integer number of samples; the float value
        # used to be passed as-is and raised a TypeError.
        # NOTE(review): the intended density of this fallback grid is unclear.
        x_scale = np.linspace(x[0], x[-1], int((1 / 100) * (len(x) - 1) + 1))

    if intertype.lower() == "spline":
        intertype = "catmull-rom"
    y_interpol = nrv_interp(x, y, kind=intertype.lower(), **kwargs_interp)
    y_scale = y_interpol(x_scale)

    if bounds[0] != bounds[1]:
        # Clamp every value into [min(bounds), max(bounds)].
        y_scale = np.clip(y_scale, min(bounds), max(bounds))

    if save:
        if save_scale and filename.endswith("csv"):
            np.savetxt(
                filename, np.transpose(np.array([x_scale, y_scale])), delimiter=","
            )
        else:
            y_values = np.asarray(y_scale)
            # Give the threshold locally instead of mutating numpy's global
            # print options; threshold == size never summarises with "...".
            y_str = np.array2string(
                y_values, separator="\n", threshold=y_values.size
            )
            with open(filename, "w") as file:
                # Strip only the enclosing "[" and "]" (slicing [1:-2] used to
                # drop the last character of the last value as well).
                file.write(y_str[1:-1])
    return y_scale
def interpolate_amp(
    position: np.ndarray,
    t_sim: float = 100,
    t_end: float = None,
    dt: float = 0.005,
    intertype: str = "Spline",
    bounds: tuple[float] = (0, 0),
    save: bool = False,
    filename: str = "interpolate_part.dat",
    save_scale: bool = False,
) -> np.ndarray:
    """
    Generate a waveform from a particle position whose entries are the
    waveform amplitudes at evenly spaced times.

    The ``len(position)`` amplitudes are placed evenly on ``[0, t_end]`` and
    resampled with :func:`interpolate` on ``int(t_end / dt)`` points. The
    result is then padded (``t_end < t_sim``) or truncated
    (``t_end > t_sim``) to cover ``t_sim``.

    Parameters
    ----------
    position : array
        Particle position; each entry is an output amplitude at regularly
        spaced times.
    t_sim : float
        Simulation time (ms), by default 100.
    t_end : float
        End time (ms) of the interpolated section, by default None (t_sim).
    dt : float
        Time step of the simulation (ms), by default 0.005.
    intertype : str
        Interpolation type forwarded to :func:`interpolate`, by default
        'Spline' (cubic spline interpolation).
    bounds : tuple
        Limit range of the interpolation; no limit when both values are equal,
        by default (0, 0).
    save : bool
        Forwarded to :func:`interpolate`: whether to save the interpolated
        values to ``filename``, by default False.
    filename : str
        Name of the output file, by default 'interpolate_part.dat'.
    save_scale : bool
        Forwarded to :func:`interpolate`, by default False.

    Returns
    -------
    waveform : np.ndarray
        Result of the interpolation.
    """
    stop = t_sim if t_end is None else t_end
    anchor_times = np.linspace(0, stop, len(position))
    n_samples = int(stop / dt)

    waveform = interpolate(
        position,
        x=anchor_times,
        scale=n_samples,
        intertype=intertype,
        bounds=bounds,
        save=save,
        filename=filename,
        save_scale=save_scale,
    )

    if stop > t_sim:
        # Interpolated past the end of the simulation: cut the surplus.
        return waveform[: int(t_sim / dt) + 1]
    if stop < t_sim:
        # NOTE(review): the tail is padded with ones (not 0 nor the last
        # sample, unlike interpolate_Npts which pads with amp_stop) — confirm.
        tail = np.ones(int((t_sim - stop) / dt))
        waveform = np.concatenate((waveform, tail))
    return waveform
def interpolate_Npts(
    position,
    t_sim: float = 100,
    dt: float = 0.005,
    amp_start: float = 0,
    amp_stop: float = 1,
    intertype: str = "Spline",
    bounds: tuple[float] = (0, 0),
    fixed_order: bool = False,
    t_end: float = None,
    t_shift: float = None,
    save: bool = False,
    fname: str = "interpolate_2pts.dat",
    plot: bool = False,
    save_scale: bool = False,
    generatefigure: bool = True,
    strict_bounds: bool = True,
    kwargs_interp: dict = None,
    **kwargs
):
    r"""
    Generate a waveform from a particle position whose entries are the
    coordinates of N points the output waveform should pass through.

    Note
    ----
    If :math:`t_{i}` and :math:`I_{i}` are the time and amplitude of the
    :math:`i`-th point, the position vector :math:`\mathcal{X}` should be:

    .. math::

        \mathcal{X} = \begin{pmatrix} t_{1} & I_{1} & t_{2} & I_{2} & ... & t_{N} & I_{N} \end{pmatrix}

    If ``position`` has an odd length, its last entry sets ``t_end`` (at least
    ``2 * N * dt``) and the point times are multiplied by it.

    Parameters
    ----------
    position : np.ndarray
        Particle position in 2N (or 2N + 1) dimensions with the coordinates of
        the N points to interpolate.
    t_sim : float
        Simulation time (ms), by default 100.
    dt : float
        Time step of the simulation (ms), by default 0.005.
    amp_start : float
        Amplitude at the beginning of the interpolation (also used before
        ``t_shift``), by default 0.
    amp_stop : float
        Amplitude at the end of the interpolation (also used to pad up to
        ``t_sim``), by default 1.
    intertype : str
        Type of interpolation, by default 'Spline'. Possible types:
        - 'Spline' : cubic spline interpolation
        - 'linear'
    bounds : tuple
        Limit range of the interpolation; no limit when both values are equal,
        by default (0, 0).
    fixed_order : bool
        Keep the points in the given order. Not implemented: a warning is
        raised and the points are sorted by time anyway.
    t_end : float
        Optional; time at which the interpolation should reach ``amp_stop``.
        Defaults to ``t_sim`` (ignored when ``position`` has an odd length).
    t_shift : float
        Optional; if not None, the interpolation is delayed by this time.
    save : bool
        If True, draw the waveform and save the figure with
        ``plt.savefig(fname)``, by default False.
    fname : str
        File name used by ``plt.savefig``, by default 'interpolate_2pts.dat'.
        NOTE(review): the default ".dat" extension is probably not a figure
        format matplotlib accepts — pass e.g. a ".png" name when saving.
    plot : bool
        If True, draw the waveform and the interpolation nodes,
        by default False.
    save_scale : bool
        Forwarded to :func:`interpolate` (which is called with
        ``save=False``), by default False.
    generatefigure : bool
        Open a new figure before drawing, by default True.
    strict_bounds : bool
        If True, values out of ``bounds`` are clipped; otherwise ``bounds`` is
        ignored, by default True.
    kwargs_interp : dict
        Additional keyword arguments unpacked into the :func:`interpolate`
        call, by default None.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    waveform : np.ndarray
        Waveform generated from position.
    """
    if kwargs_interp is None:
        kwargs_interp = {}
    n_dim = len(position)
    n_pts = n_dim // 2
    # One (time, amplitude) row per point; the reshape keeps X two-dimensional
    # even when there are no points.
    X = np.array(
        [[position[2 * k], position[2 * k + 1]] for k in range(n_pts)], dtype=float
    ).reshape(n_pts, 2)

    # If the number of dimensions is odd, the last scalar is used to set t_end.
    if n_dim % 2 == 1:
        t_end = max(position[-1], 2 * n_pts * dt)
        X[:, 0] *= t_end
    elif t_end is None:
        t_end = t_sim

    # Move points earlier than dt forward by dt, and points later than
    # t_end - n_pts * dt back by n_pts * dt. Both masks use the unshifted times.
    times = X[:, 0]  # view: edits below update X in place
    too_early = times < dt
    too_late = ~too_early & (times > t_end - n_pts * dt)
    times[too_early] += dt
    times[too_late] -= n_pts * dt

    if fixed_order:
        rise_warning(NotImplemented, "fixed_order not implemented")
    # The points are always sorted by time (fixed_order is not supported).
    X = X[np.argsort(X[:, 0])]
    # For each pair i < j, delay point j by dt when it lies less than dt after
    # point i.
    for i in range(n_pts):
        for j in range(i + 1, n_pts):
            if X[i, 0] + dt > X[j, 0]:
                X[j, 0] = X[j, 0] + dt

    # Add the (0, amp_start) and (t_end, amp_stop) end points, then sort.
    t = np.concatenate([[0], X[:, 0], [t_end]])
    x = np.concatenate([[amp_start], X[:, 1], [amp_stop]])
    order = np.argsort(t)
    t = t[order]
    x = x[order]

    # NOTE(review): kwargs_interp is unpacked into interpolate()'s own keyword
    # arguments, not passed as its ``kwargs_interp`` parameter (which reaches
    # nrv_interp) — confirm which one is intended.
    waveform = interpolate(
        y=x,
        x=t,
        scale=int(t_end / dt) + 1,
        intertype=intertype,
        bounds=bounds if strict_bounds else (0, 0),
        save=False,
        save_scale=save_scale,
        **kwargs_interp
    )

    if t_shift is not None:
        # Hold amp_start during the shift.
        waveform = np.concatenate((amp_start * np.ones(int(t_shift / dt)), waveform))
        t_end += t_shift
    if t_end < t_sim:
        # Hold amp_stop until the end of the simulation.
        waveform = np.concatenate(
            (waveform, amp_stop * np.ones(int((t_sim - t_end) / dt)))
        )
    elif t_end > t_sim:
        waveform = waveform[: int(t_sim / dt) + 1]

    if save or plot:
        T = np.linspace(0, t_sim, len(waveform))
        if generatefigure:
            plt.figure()
        plt.plot(T, waveform)
        # The nodes were placed before the shift; offset them so they line up
        # with the plotted (shifted) waveform.
        node_times = t if t_shift is None else t + t_shift
        plt.scatter(node_times, x)
        if save:
            plt.savefig(fname)
    return waveform
#################################################################### ########################### savers ################################# ####################################################################
def cost_position_saver(data, file_name="document.csv"):
    """
    Append one ``[str(cost), *position]`` row to a .csv file.

    Simple saver meant to be used in a cost_function to log each evaluated
    cost and position (see .Optim.cost_function).

    Parameters
    ----------
    data : dict
        Dict containing the keys 'cost' and 'position'.
    file_name : str
        Name of the saving file. It is passed through ``rmv_ext`` and ".csv"
        is appended to the result, by default "document.csv".
    """
    cost_cell = str(data["cost"])
    position = data["position"]
    csv_path = rmv_ext(file_name) + ".csv"
    row = [cost_cell] + [position[k] for k in range(len(position))]
    # Append mode: each call adds one row; newline="" as required by csv.
    with open(csv_path, "a", newline="") as handle:
        csv.writer(handle).writerow(row)