Source code for straxen.corrections_services

"""Return corrections from corrections DB."""

import os
from functools import lru_cache
import warnings
import pytz
from typing import List

import numpy as np
import strax
import utilix
import straxen
from urllib.parse import urlparse, parse_qs

export, __all__ = strax.exporter()

corrections_w_file = [
    "mlp_model",
    "cnn_model",
    "gcn_model",
    "s2_xy_map_mlp",
    "s2_xy_map_cnn",
    "s2_xy_map_gcn",
    "s2_xy_map",
    "s1_xyz_map_mlp",
    "s1_xyz_map_cnn",
    "s1_xyz_map_gcn",
    "fdc_map_mlp",
    "fdc_map_cnn",
    "fdc_map_gcn",
    "s1_aft_xyz_map",
    "bayes_model",
]

single_value_corrections = [
    "elife_xenon1t",
    "elife",
    "baseline_samples_nv",
    "electron_drift_velocity",
    "electron_drift_time_gate",
    "relative_light_yield",
    "electron_diffusion_cte",
]

arrays_corrections = [
    "hit_thresholds_tpc",
    "hit_thresholds_he",
    "hit_thresholds_nv",
    "hit_thresholds_mv",
]

dict_corrections = ["se_gain", "rel_extraction_eff", "avg_se_gain"]

# needed because we pass these names as strax options which then get
# paired with the default reconstruction algorithm
# important for apply_cmt_version
posrec_corrections_basenames = ["s1_xyz_map", "fdc_map", "s2_xy_map"]


[docs]@export class CMTVersionError(Exception): pass
class CMTnanValueError(Exception): pass
[docs]@export class CorrectionsManagementServices: """A class that returns corrections Corrections are set of parameters to be applied in the analysis stage to remove detector effects. Information on the strax implementation can be found at https://github.com/AxFoundation/strax/blob/master/strax/corrections.py """ def __init__(self, username=None, password=None, mongo_url=None, is_nt=True): """ :param username: corrections DB username read the .xenon_config for the users "pymongo_user" has readonly permissions to the corrections DB the "CMT admin user" has r/w permission to corrections DB and read permission to runsDB :param password: DB password :param is_nt: bool if True we are looking at nT if False we are looking at 1T """ mongo_kwargs = { "url": mongo_url, "user": username, "password": password, "database": "corrections", } corrections_collection = utilix.rundb.xent_collection(**mongo_kwargs) # Do not delete the client! self.client = corrections_collection.database.client # Setup the interface self.interface = strax.CorrectionsInterface(self.client, database_name="corrections") self.is_nt = is_nt if self.is_nt: self.collection = self.client["xenonnt"]["runs"] else: self.collection = self.client["run"]["runs_new"] def __str__(self): return self.__repr__() def __repr__(self): return str(f'{"XENONnT " if self.is_nt else "XENON1T"}-Corrections_Management_Services')
[docs] def get_corrections_config(self, run_id, config_model=None): """Get context configuration for a given correction. :param run_id: run id from runDB :param config_model: configuration model (tuple type) :return: correction value(s) """ if not isinstance(config_model, (tuple, list)) or len(config_model) != 2: raise ValueError(f"config_model {config_model} must be a tuple of length 2") model_type, version = config_model if "to_pe_model" in model_type: return self.get_pmt_gains(run_id, model_type, version) elif ( model_type in single_value_corrections or model_type in arrays_corrections or model_type in dict_corrections ): return self._get_correction(run_id, model_type, version) elif model_type in corrections_w_file: return self.get_config_from_cmt(run_id, model_type, version) else: raise ValueError( f"{model_type} not found, currently these are " f"available {single_value_corrections}, {arrays_corrections} and " f"{corrections_w_file} " )
# entry for e.g. for super runs # cache results, this would help when looking at the same gains @lru_cache(maxsize=None) def _get_correction(self, run_id, correction, version): """Smart logic to get correction from DB. :param run_id: run id from runDB :param correction: correction's name, key word (str type) :param version: local version (str type) :return: correction value(s) """ when = self.get_start_time(run_id) try: values = [] # hack to workaround to group all pmts # because every pmt is its own dataframe...of course if correction in {"pmt", "n_veto", "mu_veto"}: # get lists of pmts df_global = self.interface.read( "global_xenonnt" if self.is_nt else "global_xenon1t" ) gains = df_global["global_ONLINE"][0] # global is where all pmts are grouped pmts = list(gains.keys()) for it_correction in pmts: # loop over all PMTs if correction in it_correction: df = self.interface.read_at(it_correction, when) if df[version].isnull().values.any(): raise CMTnanValueError( f"For {it_correction} there are NaN values, this means no" f" correction available for {run_id} in version {version}, please" " check e-logbook for more info " ) if version in "ONLINE": df = self.interface.interpolate(df, when, how="fill") else: df = self.interface.interpolate(df, when) values.append(df.loc[df.index == when, version].values[0]) else: df = self.interface.read_at(correction, when) if df[version].isnull().values.any(): raise CMTnanValueError( f"For {correction} there are NaN values, this means no correction available" f" for {run_id} in version {version}, please check e-logbook for more info " ) if ( correction in corrections_w_file or correction in arrays_corrections or version in "ONLINE" or correction in dict_corrections ): df = self.interface.interpolate(df, when, how="fill") else: df = self.interface.interpolate(df, when) values.append(df.loc[df.index == when, version].values[0]) corrections = np.asarray(values) except KeyError: if "global" in version: raise ValueError( f"User is not allowed to pass {version} global version are not allowed" ) raise ValueError( f"Version {version} not found for correction {correction}, please check" ) else: return corrections
[docs] def get_pmt_gains( self, run_id, model_type, version, cacheable_versions=("ONLINE",), gain_dtype=np.float32 ): """Smart logic to return pmt gains to PE values. :param run_id: run id from runDB :param model_type: to_pe_model (gain model) :param version: version :param cacheable_versions: versions that are allowed to be cached in ./resource_cache :param gain_dtype: dtype of the gains to be returned as array :return: array of pmt gains to PE values """ to_pe = None cache_name = None if "to_pe_model" in model_type: # Get the detector name based on the requested model_type # This also will be used to the cachable name convention # pmt == TPC, n_veto == n_veto's PMT, etc detector_names = { "to_pe_model": "pmt", "to_pe_model_nv": "n_veto", "to_pe_model_mv": "mu_veto", } target_detector = detector_names[model_type] if version in cacheable_versions: # Try to load from cache, if it does not exist it will be created below cache_name = cacheable_naming(run_id, model_type, version) try: to_pe = straxen.get_resource(cache_name, fmt="npy") except (ValueError, FileNotFoundError): pass if to_pe is None: to_pe = self._get_correction(run_id, target_detector, version) # be cautious with very early runs, check that not all are None if np.isnan(to_pe).any(): raise ValueError( "to_pe(PMT gains) values are NaN, no data available " f"for {run_id} in the gain model with version" ) else: raise ValueError(f"{model_type} not implemented for to_pe values") # Double check the dtype of the gains to_pe = np.array(to_pe, dtype=gain_dtype) # Double check that all the gains are found, None is not allowed # since strax processing does not handle this well. If a PMT is # off it's gain should be 0. if np.any(np.isnan(to_pe)): pmts_affected = np.argwhere(np.isnan(to_pe))[:, 0] raise GainsNotFoundError( f"Gains returned by CMT are None for PMT_i = {pmts_affected}. " "Cannot proceed with processing. Report to CMT-maintainers." ) if ( cache_name is not None and version in cacheable_versions and not os.path.exists(cache_name) ): # This is an array we can save since it's in the cacheable # versions but it has not been saved yet. Next time we need # it, we can get it from our cache. np.save(cache_name, to_pe, allow_pickle=False) return to_pe
[docs] def get_config_from_cmt(self, run_id, model_type, version="ONLINE"): """Smart logic to return NN weights file name to be downloader by straxen.MongoDownloader() :param run_id: run id from runDB :param model_type: model type and neural network type; model_mlp, or model_gcn or model_cnn :param version: version :param return: NN weights file name """ if model_type not in corrections_w_file: raise ValueError( f"{model_type} is not stored in CMT " f"please check, these are available {corrections_w_file}" ) file_name = self._get_correction(run_id, model_type, version) if not file_name: raise ValueError( f"You have the right option but could not find a file" f"Please contact CMT manager and yell at him" ) return file_name
[docs] def get_start_time(self, run_id): """Smart logic to return start time from runsDB. :param run_id: run id from runDB :return: run start time """ if self.is_nt: # xenonnt use int run_id = int(run_id) rundoc = self.collection.find_one( {"number" if self.is_nt else "name": run_id}, {"start": 1} ) if rundoc is None: raise ValueError(f"run_id = {run_id} not found") time = rundoc["start"] return time.replace(tzinfo=pytz.utc)
[docs] def get_local_versions(self, global_version) -> dict: """Returns a dict of local versions for a given global version. Use 'latest' to get newest version """ # check that 'global' is in the passed string. if global_version == "latest": # CMT appends columns to the global versions dataframe, # so taking last one is the latests global_version = self.global_versions[-1] if "global" not in global_version: warnings.warn( "'global' does not appear in the passed global version. Are you sure this right?" ) # CMT generates a global version, global version is just a set of local versions # With this we can do pretty easy bookkeping for offline contexts cmt_global = self.interface.read("global_xenonnt") if global_version not in cmt_global: avail_global_versions_string = "\n".join([f"\t\t{v}" for v in self.global_versions]) raise ValueError( f"Global version {global_version} not found! " f"Try one of these:\n{avail_global_versions_string}" ) # get local versions from CMT global version local_versions = cmt_global[global_version][0] # to make returned dictionary more manageable, we prune all the per-PMT corrections # first rename to more clear variable local_versions["to_pe_model"] = local_versions["pmt_000_gain_xenonnt"] local_versions["to_pe_model_nv"] = local_versions["n_veto_000_gain_xenonnt"] local_versions["to_pe_model_mv"] = local_versions["mu_veto_000_gain_xenonnt"] # drop the per-PMT corrections pruned_local_versions = { key: val for key, val in local_versions.items() if "_gain_xenonnt" not in key } return pruned_local_versions
@property def global_versions(self): return self.interface.read("global_xenonnt").columns.tolist()
def cacheable_naming(*args, fmt=".npy", base="./resource_cache/"): """Convert args to consistent naming convention for array to be cached.""" if not os.path.exists(base): try: os.mkdir(base) except (FileExistsError, PermissionError): pass for arg in args: if not isinstance(arg, str): raise TypeError(f"One or more args of {args} are not strings") return base + "_".join(args) + fmt class GainsNotFoundError(Exception): """Fatal error if a None value is returned by the corrections.""" def get_cmt_local_versions(global_version): cmt = CorrectionsManagementServices() return cmt.get_local_versions(global_version) def args_idx(x): """Get the idx of "?" in the string.""" return x.rfind("?") if "?" in x else None @strax.Context.add_method def apply_cmt_version(context: strax.Context, cmt_global_version: str) -> None: """Sets all the relevant correction variables. :param cmt_global_version: A specific CMT global version, or 'latest' to get the newest one """ local_versions = get_cmt_local_versions(cmt_global_version) # get the position algorithm we are using # I feel like this should be easier... posrec_option = "default_reconstruction_algorithm" if posrec_option in context.config: posrec_algo = context.config[posrec_option] else: posrec_algo = ( context._plugin_class_registry["peak_positions"].takes_config[posrec_option].default ) cmt_options = straxen.get_corrections.get_cmt_options(context) # catch here global versions that are not compatible with this straxen version # this happens if a new correction was added to CMT that was not used in a fixed version # we want this error to occur in order to keep fixed global versions cmt_config = dict() failed_keys: List[str] = [] for option, option_info in cmt_options.items(): # name of the CMT correction, this is not always equal to the strax option correction_name = option_info["correction"] # actual config option # this could be either a CMT tuple or a URLConfig value = option_info["strax_option"] # might need to modify correction name to include position reconstruction algo # this is a bit of a mess, but posrec configs are treated differently in the tuples # URL configs should already include the posrec suffix # (it's real mess -- we should drop tuple configs) if correction_name in posrec_corrections_basenames: correction_name += f"_{posrec_algo}" # type: ignore # now see if our correction is in our local_versions dict if correction_name in local_versions: if isinstance(value, str) and "cmt://" in value: new_value = replace_url_version(value, local_versions[correction_name]) # if it is a tuple, make a new tuple else: new_value = (value[0], local_versions[correction_name], value[2]) else: if correction_name not in failed_keys: failed_keys.append(correction_name) # type: ignore continue cmt_config[option] = new_value if len(failed_keys): failed_keys = ", ".join(failed_keys) # type: ignore msg = ( f"CMT version {cmt_global_version} is not compatible with this straxen version! " f"CMT {cmt_global_version} is missing these corrections: {failed_keys}" ) # only raise a warning if we are working with the online context if cmt_global_version == "global_ONLINE": warnings.warn(msg, UserWarning) else: raise CMTVersionError(msg) context.set_config(cmt_config) @strax.Context.add_method def apply_xedocs_configs(context: strax.Context, db="straxen_db", **kwargs) -> None: import xedocs if isinstance(db, str): func = getattr(xedocs.databases, db) db_kwargs = straxen.filter_kwargs(func, kwargs) db = func(**db_kwargs) filter_kwargs = {k: v for k, v in kwargs.items() if k in db.context_configs.schema.__fields__} docs = db.context_configs.find_docs(**filter_kwargs) global_config = {doc.config_name: doc.value for doc in docs} if len(global_config): context.set_config(global_config) else: warnings.warn( f"Could not find any context configs matchin {filter_kwargs}", RuntimeWarning, stacklevel=2, ) def replace_url_version(url, version): """Replace the local version of a correction in a CMT config.""" kwargs = {k: v[0] for k, v in parse_qs(urlparse(url).query).items()} kwargs["version"] = version args = [f"{k}={v}" for k, v in kwargs.items()] args_str = "&".join(args) return f"{url[:args_idx(url)]}?{args_str}"