import warnings
import numpy as np
import strax
import straxen
import typing as ty
from functools import wraps
from straxen.corrections_services import corrections_w_file
from straxen.corrections_services import single_value_corrections
from straxen.corrections_services import arrays_corrections
from straxen.corrections_services import dict_corrections
# ``export`` is the decorator applied below as ``@export`` to mark public
# functions; ``__all__`` is the matching export list (presumably filled in by
# ``export`` — see ``strax.exporter``).
export, __all__ = strax.exporter()
def correction_options(get_correction_function):
    """Decorate a get_corrections-style function with option-tag handling.

    When ``conf`` is a tuple, the special tags ``"cmt_run_id"``, ``"prefix"``
    and ``"suffix"`` are looked up in it, in that order. Every tag that is
    found is removed from ``conf`` together with the element right after it,
    which is taken as the tag's value. Afterwards:

    * if a single element remains, that element is passed on unwrapped and the
      tag values are not applied;
    * otherwise ``<prefix>_`` is prepended and ``_<suffix>`` appended to the
      first element, ``cmt_run_id`` (if given) replaces ``run_id``, and
      ``conf`` is passed on as a tuple.

    Any non-tuple ``conf`` is forwarded unchanged.

    Example confs:
        ('cmt_run_id', cmt_run_id, 'to_pe_model', 'ONLINE', True)
        ('suffix', suffix, 'cmt_run_id', cmt_run_id, 'to_pe_model', 'ONLINE', True)

    :param get_correction_function: function called as
        ``get_correction_function(run_id, conf, *arg)``
    :return: the wrapped function (metadata copied via ``functools.wraps``)
    """

    @wraps(get_correction_function)
    def correction_options_wrapped(run_id, conf, *arg):
        if not isinstance(conf, tuple):
            # Nothing to parse: use the get corrections as they are
            return get_correction_function(run_id, conf, *arg)

        tag_values = {}
        for tag_name in ("cmt_run_id", "prefix", "suffix"):
            if tag_name not in conf:
                continue
            position = conf.index(tag_name)
            tag_values[tag_name] = conf[position + 1]
            dropped = (position, position + 1)
            # Remove the tag and its value; conf becomes a list from here on
            conf = [element for idx, element in enumerate(conf) if idx not in dropped]

        if len(conf) == 1:
            conf = conf[0]
        else:
            remaining = list(conf)
            if "prefix" in tag_values:
                remaining[0] = tag_values["prefix"] + "_" + remaining[0]
            if "suffix" in tag_values:
                remaining[0] = remaining[0] + "_" + tag_values["suffix"]
            if "cmt_run_id" in tag_values:
                run_id = tag_values["cmt_run_id"]
            conf = tuple(remaining)

        return get_correction_function(run_id, conf, *arg)

    return correction_options_wrapped
@export
@correction_options
def get_correction_from_cmt(run_id, conf):
    """Get a correction value from CMT, or a user-specified constant.

    The general CMT format is::

        conf = ('correction_name', 'version', True)

    where the last element is passed as ``is_nt`` to
    ``straxen.CorrectionsManagementServices`` (True means looking at nT runs).

    Special cases:

    * ``conf`` is a ``str``: returned unchanged (legacy support for pax files).
    * ``conf`` is a 2-tuple ``(model, value)``:

      - ``model`` in ``FIXED_TO_PE``: the fixed to_pe entry is returned and a
        warning asks to load it via the legacy config instead;
      - ``"constant"`` in ``model``: ``value`` itself is returned. It must be
        an int, float, str, list or tuple.

    ``cmt_run_id``/``prefix``/``suffix`` tags in a tuple ``conf`` are handled
    by the ``correction_options`` wrapper before this body runs.

    :param run_id: run id from runDB
    :param conf: configuration, see above
    :return: correction value(s). For CMT lookups the return type depends on
        which module-level list ``model`` belongs to: a space-joined ``str``
        (``corrections_w_file``), the first element (``dict_corrections``),
        an ``int``/``float`` (``single_value_corrections``), an ``int16`` array
        (``arrays_corrections``), or the raw CMT result otherwise.
    :raises ValueError: for a malformed configuration (including an
        unrecognised 2-tuple), a constant of the wrong type, or when CMT
        returns no value.
    """
    if isinstance(conf, str):
        # Legacy support for pax files
        return conf

    if isinstance(conf, tuple) and len(conf) == 2:
        model_conf, cte_value = conf
        # special case constant to_pe values should be covered by legacy protocols?
        # Local import kept from the original (presumably avoids an import
        # cycle at module load time — TODO confirm).
        from straxen.legacy.xenon1t_url_configs import FIXED_TO_PE

        if model_conf in FIXED_TO_PE:
            warnings.warn("Don't load like this, but via legacy config")
            return FIXED_TO_PE[model_conf]
        # special case constant single value or list of values.
        if "constant" in model_conf:
            if not isinstance(cte_value, (float, int, str, list, tuple)):
                raise ValueError(
                    f"User specify a model type {model_conf} "
                    "and should provide a number or list of numbers. Got: "
                    f"{type(cte_value)}"
                )
            return cte_value
        # Any other 2-tuple is not a valid configuration: fall through to the
        # "Wrong configuration" error instead of silently returning None.

    elif isinstance(conf, tuple) and len(conf) == 3:
        model_conf, _, is_nt = conf
        cmt = straxen.CorrectionsManagementServices(is_nt=is_nt)
        correction = cmt.get_corrections_config(run_id, conf[:2])
        if correction.size == 0:
            raise ValueError(
                f"Could not find a value for {model_conf} please check it is implemented in CMT. "
            )
        if model_conf in corrections_w_file:  # file's name (maps, NN, etc)
            return " ".join(map(str, correction))
        if model_conf in dict_corrections:
            return correction[0]
        if model_conf in single_value_corrections:
            # Squeeze a size-1 array to 0-d so int()/float() do not hit NumPy's
            # deprecated "ndim > 0 array to scalar" conversion. A larger array
            # still raises TypeError, as before.
            value = np.squeeze(correction)
            if "samples" in model_conf:  # int baseline samples, etc
                return int(value)
            return float(value)  # float elife, drift velocity, etc
        if model_conf in arrays_corrections:
            # not sure if straxen can handle dtype:object therefore specify dtype
            return correction.reshape(correction.size).astype(np.int16)
        return correction

    raise ValueError(
        "Wrong configuration. "
        "Please use the following format: "
        "(config->str, model_config->str or number, is_nT->bool) "
        f"User specify {conf} please modify"
    )
@export
def get_cmt_resource(run_id, conf, fmt=""):
    """Load the resource whose file name is stored in CMT.

    :param run_id: run id used for the CMT lookup
    :param conf: configuration handed to ``get_correction_from_cmt``
    :param fmt: format forwarded to ``straxen.get_resource``
    :return: whatever ``straxen.get_resource`` returns for that file name
    """
    resource_name = get_correction_from_cmt(run_id, conf)
    return straxen.get_resource(resource_name, fmt=fmt)
@export
def is_cmt_option(config):
    """Tell whether ``config`` is a CMT-style configuration.

    The check is delegated to ``_is_cmt_option`` with no run id. That function
    is wrapped by ``correction_options``, so tuple tags such as ``cmt_run_id``,
    ``prefix`` or ``suffix`` are stripped before the check.
    """
    no_run_id = None
    # Positional call: the wrapper's parameters are (run_id, conf, *arg)
    return _is_cmt_option(no_run_id, config)
@correction_options
def _is_cmt_option(run_id, config):
    """Return True if ``config`` is a CMT option, after tags are stripped.

    Two forms are accepted:

    * a string containing ``"cmt://"`` (URLConfig style);
    * a 3-tuple ``(str, str | int | float, bool)``.

    ``run_id`` is not used here; it only satisfies the ``correction_options``
    wrapper's calling convention.
    """
    # Compatibilty with URLConfig
    if isinstance(config, str):
        return "cmt://" in config
    if not (isinstance(config, tuple) and len(config) == 3):
        return False
    name, version, is_nt = config
    return (
        isinstance(name, str)
        and isinstance(version, (str, int, float))
        and isinstance(is_nt, bool)
    )
def get_cmt_options(
    context: strax.Context,
) -> ty.Dict[str, ty.Dict[str, ty.Union[str, tuple]]]:
    """Collect the CMT-style options of all plugins registered in ``context``.

    For each config option taken by a registered plugin, the value in
    ``context.config`` is used if it passes ``is_cmt_option``. Otherwise the
    option's default is used if it passes. Options taken by several plugins
    are processed once.

    NOTE(review): a user-set value that is *not* CMT-style does not stop a
    CMT-style default from being reported — confirm this is intended.

    :param context: Context with registered plugins.
    :return: dict keyed by option name. Each value is a dict with
        ``"correction"`` (the CMT correction name) and ``"strax_option"``
        (the option value as configured: a tuple or a ``cmt://`` URL string).
    :raises RuntimeError: if resolving a ``cmt://`` option needs the run id.
    """
    cmt_options = {}
    runid_test_str = "0000"
    for data_type, plugin in context._plugin_class_registry.items():
        for option_key, option in plugin.takes_config.items():
            if option_key in cmt_options:
                # let's not do work twice if needed by > 1 plugin
                continue
            if option_key in context.config and is_cmt_option(context.config[option_key]):
                opt = context.config[option_key]
            elif is_cmt_option(option.default):
                opt = option.default
            else:
                continue

            if isinstance(opt, str) and "cmt://" in opt:
                # URLConfig style
                correction_name = _url_correction_name(
                    context, data_type, option_key, option, opt, runid_test_str
                )
            else:
                # Tuple style; may carry cmt_run_id/prefix/suffix tags
                correction_name = _tuple_correction_name(None, opt)
            cmt_options[option_key] = {
                "correction": correction_name,
                "strax_option": opt,
            }
    return cmt_options


@correction_options
def _tuple_correction_name(run_id, conf):
    """Return the correction name of a tuple-style CMT option.

    The ``correction_options`` wrapper strips ``cmt_run_id``/``prefix``/
    ``suffix`` tags and applies prefix/suffix to the first element before
    this runs. ``conf[0]`` is therefore the name ``get_correction_from_cmt``
    would query, rather than a tag such as ``"cmt_run_id"``.
    """
    if isinstance(conf, tuple):
        return conf[0]
    # The wrapper unwraps a single remaining element; report it as is.
    return conf


def _url_correction_name(context, data_type, option_key, option, url, run_id):
    """Resolve the CMT correction name of a ``cmt://`` URLConfig option.

    A fresh plugin for ``data_type`` is configured with the placeholder
    ``run_id``. Its run id is then deleted, and the part of ``url`` after
    ``cmt://`` is evaluated through the plugin's ``option_key`` attribute.
    """
    _, _, after_cmt = url.partition("cmt://")
    plugin = context._get_plugins((data_type,), run_id)[data_type]
    context._set_plugin_config(plugin, run_id, tolerant=False)
    # Evaluate the option without a run id so run-id dependence is detected
    del plugin.run_id
    plugin.config[option_key] = after_cmt
    try:
        correction_name = getattr(plugin, option_key)
    except AttributeError as err:
        # make sure the correction name does not depend on runid
        raise RuntimeError(
            "Correction names should not depend on runids! "
            f"Please check your option for {option_key}"
        ) from err
    # if there is no other protocol being called before cmt,
    # we will get a string back including the query part
    if option.QUERY_SEP in correction_name:
        correction_name, _ = option.split_url_kwargs(correction_name)
    return correction_name