Source code for straxen.get_corrections

import warnings

import numpy as np
import strax
import straxen
import typing as ty
from functools import wraps
from straxen.corrections_services import corrections_w_file
from straxen.corrections_services import single_value_corrections
from straxen.corrections_services import arrays_corrections
from straxen.corrections_services import dict_corrections


export, __all__ = strax.exporter()


def correction_options(get_correction_function):
    """
    Decorate a get_corrections function so that it understands special options.

    When ``conf`` is a tuple, the markers "cmt_run_id", "prefix" and "suffix"
    are searched for, in that order. Every marker that is found is removed
    from ``conf`` together with the element right after it, which is taken
    as the marker's value.
    Example confs:
        ('cmt_run_id', cmt_run_id, 'to_pe_model', 'ONLINE', True)
        ('suffix', suffix, 'cmt_run_id', cmt_run_id, 'to_pe_model', 'ONLINE', True)

    If exactly one element is left after the removal, that element alone is
    forwarded as ``conf`` and the extracted values are not applied. Otherwise
    the prefix and suffix are joined to the first remaining element with "_",
    ``run_id`` is replaced by the "cmt_run_id" value, and the remainder is
    forwarded as a tuple. A ``conf`` that is not a tuple is forwarded as is.

    :param get_correction_function: A function here in the get_corrections module
    :return: The function wrapped with the option search
    """

    @wraps(get_correction_function)
    def correction_options_wrapped(run_id, conf, *arg):
        # Anything that is not a tuple is passed through untouched
        if not isinstance(conf, tuple):
            return get_correction_function(run_id, conf, *arg)

        extracted = {}
        remaining = conf
        for marker in ("cmt_run_id", "prefix", "suffix"):
            if marker not in remaining:
                continue
            position = remaining.index(marker)
            extracted[marker] = remaining[position + 1]
            # Drop the marker and its value; build a list so that the first
            # element can be rewritten below
            remaining = [*remaining[:position], *remaining[position + 2 :]]

        # A single leftover element is handed over bare, not as a 1-tuple
        if len(remaining) == 1:
            return get_correction_function(run_id, remaining[0], *arg)

        if "prefix" in extracted:
            remaining[0] = extracted["prefix"] + "_" + remaining[0]
        if "suffix" in extracted:
            remaining[0] = remaining[0] + "_" + extracted["suffix"]
        if "cmt_run_id" in extracted:
            run_id = extracted["cmt_run_id"]

        return get_correction_function(run_id, tuple(remaining), *arg)

    return correction_options_wrapped


@export
@correction_options
def get_correction_from_cmt(run_id, conf):
    """
    Get correction from CMT.

    General format is conf = ('correction_name', 'version', True),
    where True means looking at nT runs.

    Special cases:
      * a plain string is returned unchanged (legacy support for pax files);
      * a 2-tuple ('name', value): if 'name' is a key of the legacy
        FIXED_TO_PE table that entry is returned (with a warning); if
        'name' contains "constant", ``value`` itself is returned, so the
        version can be replaced by a constant int, float, str or list/tuple
        of values specified by the user.

    :param run_id: run id from runDB
    :param conf: configuration
    :return: correction value(s)
    :raises ValueError: if ``conf`` has an unsupported format, if a
        "constant" model is given a value of an unsupported type, or if CMT
        returns an empty result for the requested correction
    """
    if isinstance(conf, str):
        # Legacy support for pax files
        return conf

    if isinstance(conf, tuple) and len(conf) == 2:
        model_conf, cte_value = conf

        # special case constant to_pe values should be covered by legacy protocols?
        # Imported here rather than at module level, as in the original code.
        from straxen.legacy.xenon1t_url_configs import FIXED_TO_PE

        if model_conf in FIXED_TO_PE:
            warnings.warn("Don't load like this, but via legacy config")
            return FIXED_TO_PE[model_conf]

        # special case constant single value or list of values.
        if "constant" in model_conf:
            if not isinstance(cte_value, (float, int, str, list, tuple)):
                raise ValueError(
                    f"User specify a model type {model_conf} "
                    "and should provide a number or list of numbers. Got: "
                    f"{type(cte_value)}"
                )
            return cte_value

        # Previously this case fell through and silently returned None;
        # report it like every other malformed configuration instead.
        raise ValueError(
            "Wrong configuration. "
            f"Two-element configuration {conf} is neither a legacy fixed to_pe "
            "entry nor a 'constant' model. "
            "Please use the following format: "
            "(config->str, model_config->str or number, is_nT->bool)"
        )

    if isinstance(conf, tuple) and len(conf) == 3:
        model_conf, _, is_nt = conf
        cmt = straxen.CorrectionsManagementServices(is_nt=is_nt)
        # Only the (name, version) pair is passed on to CMT
        correction = cmt.get_corrections_config(run_id, conf[:2])
        if correction.size == 0:
            raise ValueError(
                f"Could not find a value for {model_conf} "
                "please check it is implemented in CMT. \n"
            )

        if model_conf in corrections_w_file:
            # file's name (maps, NN, etc)
            return " ".join(map(str, correction))

        if model_conf in dict_corrections:
            return correction[0]

        if model_conf in single_value_corrections:
            if "samples" in model_conf:
                # int baseline samples, etc
                return int(correction)
            # float elife, drift velocity, etc
            return float(correction)

        if model_conf in arrays_corrections:
            np_correction = correction.reshape(correction.size)
            # not sure if straxen can handle dtype:object therefore specify dtype
            return np_correction.astype(np.int16)

        return correction

    raise ValueError(
        "Wrong configuration. "
        "Please use the following format: "
        "(config->str, model_config->str or number, is_nT->bool) "
        f"User specify {conf} please modify"
    )
@export
def get_cmt_resource(run_id, conf, fmt=""):
    """
    Get resource with CMT correction file name.

    The correction for ``conf`` is looked up with get_correction_from_cmt and
    the result is handed to straxen.get_resource together with ``fmt``.

    :param run_id: run id passed on to get_correction_from_cmt
    :param conf: configuration passed on to get_correction_from_cmt
    :param fmt: format passed on to straxen.get_resource
    :return: whatever straxen.get_resource returns
    """
    resource_name = get_correction_from_cmt(run_id, conf)
    return straxen.get_resource(resource_name, fmt=fmt)
@export
def is_cmt_option(config):
    """
    Check if the input configuration is cmt style.

    :param config: configuration value to inspect
    :return: result of _is_cmt_option for ``config``
    """
    # The check does not depend on a run, so no run id is supplied
    placeholder_run_id = None
    return _is_cmt_option(placeholder_run_id, config)
@correction_options
def _is_cmt_option(run_id, config):
    """
    Tell whether ``config`` is a CMT style configuration.

    :param run_id: unused; present because of the correction_options signature
    :param config: configuration after the special options have been resolved
    :return: True for a string containing "cmt://" or for a 3-tuple of
        (str, str/int/float, bool), False otherwise
    """
    # Compatibilty with URLConfig
    if isinstance(config, str) and "cmt://" in config:
        return True
    is_cmt = (
        isinstance(config, tuple)
        and len(config) == 3
        and isinstance(config[0], str)
        and isinstance(config[1], (str, int, float))
        and isinstance(config[2], bool)
    )
    return is_cmt


@correction_options
def _resolve_special_options(run_id, config):
    """Return ``config`` as rewritten by the correction_options decorator."""
    return config


def get_cmt_options(context: strax.Context) -> ty.Dict[str, ty.Dict[str, tuple]]:
    """Function which loops over all plugin configs and returns dictionary with option name as key
    and a nested dict of CMT correction name and strax option as values.

    Each nested dict has the keys "correction" (the correction name) and
    "strax_option" (the option value exactly as configured). For tuple options
    carrying "cmt_run_id", "prefix" or "suffix" markers, the correction name is
    taken after those markers are resolved, so it is the correction itself and
    not the marker string.

    :param context: Context with registered plugins.
    :return: dictionary of option name -> {"correction": ..., "strax_option": ...}
    :raises RuntimeError: if a "cmt://" option cannot be evaluated on a plugin
        whose run_id attribute has been removed
    """
    cmt_options = {}
    runid_test_str = "0000"

    for data_type, plugin in context._plugin_class_registry.items():
        for option_key, option in plugin.takes_config.items():
            if option_key in cmt_options:
                # let's not do work twice if needed by > 1 plugin
                continue

            # The value set in the context takes precedence over the default
            if option_key in context.config and is_cmt_option(context.config[option_key]):
                opt = context.config[option_key]
            elif is_cmt_option(option.default):
                opt = option.default
            else:
                continue

            # check if it's a URLConfig
            if isinstance(opt, str) and "cmt://" in opt:
                _, _, after_cmt = opt.partition("cmt://")
                p = context._get_plugins((data_type,), runid_test_str)[data_type]
                context._set_plugin_config(p, runid_test_str, tolerant=False)
                # Remove run_id so that evaluating the option below fails with
                # an AttributeError if the correction name needs it
                del p.run_id

                p.config[option_key] = after_cmt
                try:
                    correction_name = getattr(p, option_key)
                except AttributeError as err:
                    # make sure the correction name does not depend on runid
                    raise RuntimeError(
                        "Correction names should not depend on runids! "
                        f"Please check your option for {option_key}"
                    ) from err

                # if there is no other protocol being called before cmt,
                # we will get a string back including the query part
                if option.QUERY_SEP in correction_name:
                    correction_name, _ = option.split_url_kwargs(correction_name)
            else:
                # Resolve "cmt_run_id"/"prefix"/"suffix" first; otherwise the
                # first element could be one of those markers instead of the
                # correction name. Plain 3-tuples are returned unchanged.
                resolved = _resolve_special_options(None, opt)
                correction_name = resolved if isinstance(resolved, str) else resolved[0]

            cmt_options[option_key] = {
                "correction": correction_name,
                "strax_option": opt,
            }

    return cmt_options