Source code for nrv.optim.optim_utils._OptimResults

import matplotlib.pyplot as plt
from .._CostFunctions import cost_function
import numpy as np

from ...backend._NRV_Results import NRV_results
from ...backend._log_interface import pass_info, rise_warning


[docs] class optim_results(NRV_results): """ Container class storing optimization parameters, history, and best solution data. """
[docs] def __init__(self, context=None): """ Initialize an optimization-results container. Parameters ---------- context : Any, optional Optimization-parameter context stored under ``optimization_parameters``. """ super().__init__({"optimization_parameters": context})
[docs] def load(self, data, blacklist=[], **kwargs): """ Load optimization results and reserialize the nested optimization parameters. """ super().load(data, blacklist, **kwargs) self["optimization_parameters"] = self["optimization_parameters"].save( save=False )
############################ ## Processing methods ## ############################ ## PSO relative methods
[docs] def is_stabilized(self, part, it, threshold=0.1): """ Check whether one particle velocity has fallen below a stabilization threshold. Parameters ---------- part : int Particle identifier. it : int Iteration index. threshold : float, optional Absolute velocity threshold. Returns ------- bool ``True`` when all velocity components are below the threshold. """ velocity = self["velocity" + str(part)][it] for i in velocity: if abs(i) >= threshold: return False return True
[docs] def stabilization_it(self, parts=None, nit=None, threshold=None): """ Estimate the last iteration at which selected particles were not yet stabilized. Parameters ---------- parts : int | iterable | None, optional Particle identifier, iterable of identifiers, or ``None`` for the whole swarm. nit : int | None, optional Number of iterations to inspect. threshold : float | None, optional Stabilization threshold. If omitted, derive it from particle 1 velocity history. Returns ------- int | np.ndarray | None Stabilization iteration index or indices. """ if threshold == None: threshold = max(max(self["velocity1"])) / 100 pass_info("No threshold in parameters, set to Vpart1max/100 = ", threshold) # One particle if type(parts) == int: if nit is None: nit = self.nit for i in range(nit - 1, 0, -1): if not self.is_stabilized(parts, i, threshold): return i return None # A particle list elif np.iterable(parts): list_it = [] for i in parts: list_it += [ self.stabilization_it(parts=int(i), nit=nit, threshold=threshold) ] return np.array(list_it) # The whole swram else: swarm = 1 + np.arange(self["optimization_parameters"]["n_particles"]) return self.stabilization_it(swarm, nit=nit, threshold=threshold)
[docs] def add_filter(self, part_filter): """ Apply a post-processing filter to every stored particle position. Parameters ---------- part_filter : callable Filter applied to each position vector. Returns ------- optim_results Updated results object. """ n_particles = self["optimization_parameters"]["n_particles"] nit = self.nit resultsf = self for i in range(n_particles): resultsf["position" + str(i + 1) + " filtered"] = [[] for j in range(nit)] for j in range(nit): part = resultsf["position" + str(i + 1)][j] filtered_part = part_filter(part) resultsf["position" + str(i + 1) + " filtered"][j] = filtered_part return resultsf
[docs] def findbestpart(self, decimals=10, verbose=False, lim_it=None): """ Find which particle reached the recorded best position. Parameters ---------- decimals : int, optional Decimal precision used for approximate comparison. verbose : bool, optional If ``True``, print the search progress. lim_it : int | None, optional Maximum iteration index inspected. Returns ------- int Particle identifier, or ``-1`` when the search fails. """ if decimals < -15: rise_warning("Best results not founded returning -1") return -1 n_particles = self["optimization_parameters"]["n_particles"] nit = self.nit bestpos = self["best_position"] ibestpart = 0 if lim_it is None: lim_it = nit for j in range(lim_it): for i in range(1, 1 + n_particles): pos = self["position" + str(i)] if all(np.around(pos[j], decimals) == np.around(bestpos, decimals)): if verbose: print("it=", j) return i if verbose: pass_info("not found with decimals =", decimals) return self.findbestpart(decimals - 1, verbose=verbose, lim_it=lim_it)
[docs] def compute_best_pos(self, cost_function: cost_function, **kwrgs): """ Recompute simulation results at the recorded best position. Parameters ---------- cost_function : cost_function Cost-function object able to simulate the corresponding context. **kwrgs : dict Reserved for future use. Returns ------- sim_results Simulation results at the best position. """ return cost_function.get_sim_results(self.x)
############################ ## plotting methods ## ############################
[docs] def plot_cost_history( self, ax: plt.axes, nitstop: int = -1, xlog: bool = False, ylog: bool = False, **ax_kwargs, ): """ Plot the optimization cost history. Parameters ---------- ax : matplotlib.axes.Axes Target axes. nitstop : int, optional Final iteration displayed. xlog : bool, optional If ``True``, use a logarithmic x-axis. ylog : bool, optional If ``True``, use a logarithmic y-axis. **ax_kwargs : dict Additional plotting keyword arguments. """ cost = self["cost_history"] ax.plot(cost[0:nitstop], **ax_kwargs) if xlog: ax.set_xscale("log") if ylog: ax.set_yscale("log")