import sys
from sys import getsizeof, stderr
import inspect
import warnings
import datetime
from immutabledict import immutabledict
import pytz
from itertools import chain
from collections import defaultdict, OrderedDict, deque
from importlib import import_module
from platform import python_version
import typing as ty
from copy import deepcopy
import numpy as np
import pandas as pd
import socket
import graphviz
import strax
import straxen
from git import Repo, InvalidGitRepositoryError
from configparser import NoSectionError
try:
# pylint: disable=redefined-builtin
from reprlib import repr
except ImportError:
pass
# True when this process appears to have been launched by Jupyter (used below to
# decide whether print_versions prints by default).
_is_jupyter = any("jupyter" in arg for arg in sys.argv)

export, __all__ = strax.exporter()
__all__.extend(["kind_colors"])

# Node fill colors per data kind, used by the graphviz plots in this module
# (dependency_tree / storage_graph). Unknown kinds fall back to grey there.
kind_colors = dict(
    events="#ffffff",
    peaks="#98fb98",
    hitlets="#0066ff",
    peaklets="#d9ff66",
    merged_s2s="#ccffcc",
    lone_hits="#CAFF70",
    records="#ffa500",
    raw_records="#ff4500",
    raw_records_aqmon="#ff4500",
    raw_records_aux_mv="#ff4500",
    online_peak_monitor="deepskyblue",
    online_monitor="deepskyblue",
)
@export
def dataframe_to_wiki(df, float_digits=5, title="Awesome table", force_int: ty.Tuple = ()):
    """Convert a pandas dataframe to a dokuwiki table (which you can copy-paste onto the XENON wiki)

    :param df: dataframe to convert
    :param float_digits: format float to this number of digits.
    :param title: title of the table.
    :param force_int: tuple of column names to force to be integers
    :return: str, the dokuwiki-formatted table
    """
    # Header: a title cell spanning all columns, then one header cell per column.
    table = "^ %s " % title + "^" * (len(df.columns) - 1) + "^\n"
    table += "^ " + " ^ ".join(df.columns) + " ^\n"

    def format_float(x):
        # Only floats are rounded; any other type is passed through unchanged.
        if isinstance(x, float):
            return f"{x:.{float_digits}f}"
        return x

    # Positional indices of the columns that must be rendered as integers.
    # np.isin replaces np.in1d, which is deprecated since NumPy 1.25.
    force_int = np.where(np.isin(df.columns.values, strax.to_str_tuple(force_int)))[0]
    for _, row in df.iterrows():
        table += (
            "| "
            + " | ".join(
                [
                    str(int(x) if i in force_int else format_float(x))
                    for i, x in enumerate(row.values.tolist())
                ]
            )
            + " |\n"
        )
    return table
@export
def print_versions(
    modules=("strax", "straxen", "cutax"),
    print_output=not _is_jupyter,
    include_python=True,
    return_string=False,
    include_git=True,
):
    """Print versions of modules installed.

    :param modules: Modules to print, should be str, tuple or list. E.g.
        print_versions(modules=('numpy', 'dddm',))
    :param print_output: print the assembled message (defaults to True outside Jupyter)
    :param include_python: also report the python interpreter version and executable
    :param return_string: optional. Instead of printing the message, return a string
    :param include_git: Include the current branch and latest commit hash
    :return: optional, the message that would have been printed
    """
    rows = defaultdict(list)

    def _add_row(name, version, path, git):
        # One column-wise append per reported module.
        rows["module"].append(name)
        rows["version"].append(version)
        rows["path"].append(path)
        rows["git"].append(git)

    if include_python:
        _add_row("python", python_version(), sys.executable, None)
    for module_name in strax.to_str_tuple(modules):
        lookup = _version_info_for_module(module_name, include_git=include_git)
        if lookup is None:
            # Module not importable; a message was already printed by the helper.
            continue
        _add_row(module_name, *lookup)
    df = pd.DataFrame(rows)
    message = f"Host {socket.getfqdn()}\n{df.to_string(index=False)}"
    if print_output:
        print(message)
    if return_string:
        return message
    return df
def _version_info_for_module(module_name, include_git):
try:
mod = import_module(module_name)
except (ModuleNotFoundError, ImportError):
print(f"{module_name} is not installed")
return
git = None
version = mod.__dict__.get("__version__", None)
module_path = mod.__dict__.get("__path__", [None])[0]
if include_git:
try:
repo = Repo(module_path, search_parent_directories=True)
except InvalidGitRepositoryError:
# not a git repo
pass
else:
try:
branch = repo.active_branch
except TypeError:
branch = "unknown"
try:
commit_hash = repo.head.object.hexsha
except TypeError:
commit_hash = "unknown"
git = f"branch:{branch} | {commit_hash[:7]}"
return version, module_path, git
@strax.Context.add_method
def extract_latest_comment(self):
    """Extract the latest comment in the runs-database. This just adds info to st.runs.

    Example:
        st.extract_latest_comment()
        st.select_runs(available=('raw_records'))
    """
    # Make sure the comments field has actually been fetched from the run database.
    if self.runs is None or "comments" not in self.runs.keys():
        self.scan_runs(store_fields=("comments",))
    # Replace the raw comment lists with just the most recent comment per run.
    self.runs["comments"] = _parse_to_last_comment(self.runs["comments"])
    return self.runs
def _parse_to_last_comment(comments):
"""Unpack to get the last comment (hence the -1) or give '' when there is none."""
return [(c[-1]["comment"] if hasattr(c, "__len__") else "") for c in comments]
@export
def convert_array_to_df(array: np.ndarray) -> pd.DataFrame:
    """Converts the specified array into a DataFrame drops all higher dimensional fields during the
    process.

    :param array: numpy.array to be converted.
    :return: DataFrame with higher dimensions dropped.
    """
    # Keep only the scalar (1-D) fields; multi-dimensional fields cannot become columns.
    scalar_fields = list(filter(lambda name: array[name].ndim == 1, array.dtype.names))
    return pd.DataFrame(array[scalar_fields])
@export
def filter_kwargs(func, kwargs):
    """Filter out keyword arguments that are not in the call signature of func and return filtered
    kwargs dictionary."""
    params = inspect.signature(func).parameters
    # A **kwargs catch-all parameter means every keyword is acceptable.
    accepts_any_kwarg = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
    if accepts_any_kwarg:
        return kwargs
    return {name: value for name, value in kwargs.items() if name in params}
@export
class CacheDict(OrderedDict):
"""Dict with a limited length, ejecting LRUs as needed.
copied from
https://gist.github.com/davesteele/44793cd0348f59f8fadd49d7799bd306
"""
def __init__(self, *args, cache_len: int = 10, **kwargs):
assert cache_len > 0
self.cache_len = cache_len
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().move_to_end(key)
while len(self) > self.cache_len:
oldkey = next(iter(self))
super().__delitem__(oldkey)
def __getitem__(self, key):
val = super().__getitem__(key)
super().move_to_end(key)
return val
@export
def total_size(o, handlers=None, verbose=False):
    """Returns the approximate memory footprint an object and all of its contents.

    Automatically finds the contents of the following builtin containers and
    their subclasses: tuple, list, deque, dict, set and frozenset.
    To search other containers, add handlers to iterate over their contents:
    handlers = {SomeContainerClass: iter,
                OtherContainerClass: OtherContainerClass.get_elements}
    from: https://code.activestate.com/recipes/577504/
    """

    def _dict_contents(d):
        # Walk both keys and values of a mapping.
        return chain.from_iterable(d.items())

    # Order matters: dict is checked before set/frozenset via insertion order.
    type_handlers = {
        tuple: iter,
        list: iter,
        deque: iter,
        dict: _dict_contents,
        set: iter,
        frozenset: iter,
    }
    if handlers is not None:
        # user handlers take precedence
        type_handlers.update(handlers)

    visited = set()  # object ids already counted, so shared objects count once
    fallback_size = getsizeof(0)  # estimate for objects without __sizeof__

    def _sizeof(obj):
        if id(obj) in visited:
            return 0
        visited.add(id(obj))
        size = getsizeof(obj, fallback_size)
        if verbose:
            print(size, type(obj), repr(obj), file=stderr)
        for container_type, get_contents in type_handlers.items():
            if isinstance(obj, container_type):
                size += sum(map(_sizeof, get_contents(obj)))
                break
        return size

    return _sizeof(o)
@strax.Context.add_method
def dependency_tree(
    st,
    target="event_info",
    dump_plot=True,
    to_dir="./",
    format="svg",
):
    """Draw the graph of plugins (and their dependencies) needed to compute ``target``.

    :param st: strax context (this function is attached as a Context method)
    :param target: data type whose plugin dependency tree is drawn
    :param dump_plot: if True, render the graph to disk via graphviz
    :param to_dir: directory prefix used in the graph name (and thus output path)
    :param format: output format passed to graphviz render, e.g. "svg"
    """
    # Clear the cached plugin set so _get_plugins reflects the current context config.
    st._fixed_plugin_cache = None
    # NOTE(review): run_id "0" looks like a placeholder for plugin resolution only —
    # confirm plugin lookup is run-independent here.
    plugins = st._get_plugins((target,), run_id="0")
    graph = graphviz.Digraph(name=f"{to_dir}/{target}", strict=True)
    graph.attr(bgcolor="transparent")
    for d, p in plugins.items():
        # One node per data type, colored by its data kind (grey if unmapped).
        graph.node(
            d,
            style="filled",
            fillcolor=kind_colors.get(p.data_kind_for(d), "grey"),
        )
        # One edge per dependency of this plugin.
        for dep in p.depends_on:
            graph.edge(d, dep)
    # dump the plot if need
    if dump_plot:
        graph.render(format=format)
@strax.Context.add_method
def storage_graph(
    st,
    run_id,
    target,
    graph=None,
    not_stored=None,
    dump_plot=True,
    to_dir="./",
    format="svg",
):
    """Plot the dependency graph indicating the storage of the plugins.

    :param run_id: str, run id used when checking whether data is stored
    :param target: str of the target plugin to check
    :param graph: graphviz.graphs.Digraph instance
    :param not_stored: set of plugins which are not stored
    :param dump_plot: bool, if True, save the plot to the to_dir
    :param to_dir: str, directory to save the plot
    :param format: str, format of the plot
    :return: all plugins that will be calculated when running self.make(run_id, target)
    The colors used in the graph represent the following storage states:
    - grey: strax.SaveWhen.NEVER
    - red: strax.SaveWhen.EXPLICIT
    - orange: strax.SaveWhen.TARGET
    - yellow: strax.SaveWhen.ALWAYS
    - green: target is stored
    """
    # Node color per SaveWhen level, used for targets that are not yet stored.
    save_when_colors = {
        strax.SaveWhen.NEVER: "grey",
        strax.SaveWhen.EXPLICIT: "red",
        strax.SaveWhen.TARGET: "orange",
        strax.SaveWhen.ALWAYS: "yellow",
    }
    if not_stored is None:
        not_stored = set()
        # the set of plugins which are not stored
    if st.is_stored(run_id, target):
        # if the plugin is stored, fill in green
        fillcolor = "green"
    else:
        # deepcopy so the registry's save_when is never mutated here
        save_when = deepcopy(st._plugin_class_registry[target].save_when)
        if isinstance(save_when, immutabledict):
            # save_when can be a per-data-type mapping; pick this target's entry
            save_when = save_when[target]
        # if it is not stored, fill in the color according to save_when
        fillcolor = save_when_colors[save_when]
    if graph is None:
        # Top-level call: start a fresh graph named after run and target.
        graph = graphviz.Digraph(name=f"{to_dir}/{run_id}-{target}", strict=True)
        graph.attr(bgcolor="transparent")
    else:
        if not isinstance(graph, graphviz.graphs.Digraph):
            raise ValueError("graph should be a graphviz.Digraph instance!")
    # add a node of target
    graph.node(
        target,
        style="filled",
        fillcolor=fillcolor,
    )
    if (not st.is_stored(run_id, target)) and (target not in not_stored):
        not_stored.add(target)
        # Recurse into the dependencies only when this target still needs computing.
        depends_on = deepcopy(st._plugin_class_registry[target].depends_on)
        depends_on = strax.to_str_tuple(depends_on)
        for dep in depends_on:
            # only add the node to the graph but not save the plot
            not_stored.update(
                st.storage_graph(
                    run_id,
                    dep,
                    graph=graph,
                    not_stored=not_stored,
                    dump_plot=False,
                    to_dir=to_dir,
                )
            )
            graph.edge(target, dep)
    # dump the plot if need
    if dump_plot:
        graph.render(format=format)
    return not_stored