# Source code for straxen.analyses.event_display

from datetime import datetime

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pytz
import strax
import straxen

# strax.exporter() supplies the `export` decorator (applied to public
# functions further down) together with this module's `__all__`.
# NOTE(review): presumably `export` registers the decorated name in
# `__all__` -- confirm against the strax documentation.
export, __all__ = strax.exporter()

# Default attributes to display in the event_display. The peak list looks
# a little complicated, but it just repeats the same fields for S1 and S2.
# Entries should be of the form below, where {v} will be filled with the
# value of event['key']:
#  (('key', '{v} UNIT'), ..)
PEAK_DISPLAY_DEFAULT_INFO: list = [
    # The outer loop runs over the peak type (1 -> S1, 2 -> S2) so that all
    # S1 fields come first, followed by the same fields for S2.
    (key.format(i=s_i), unit)
    for s_i in (1, 2)
    for key, unit in (
        ("cs{i}", "{v:.2f} PE"),
        ("s{i}_area", "{v:.2f} PE"),
        ("alt_cs{i}", "{v:.2f} PE"),
        ("s{i}_n_channels", "{v}"),
        ("s{i}_area_fraction_top", "{v:.2f}"),
        ("s{i}_range_50p_area", "{v:.1f}"),
    )
]
EVENT_DISPLAY_DEFAULT_INFO = (
    ("time", "{v} ns"),
    ("endtime", "{v} ns"),
    ("event_number", "{v}"),
    ("x", "{v:.2f} cm"),
    ("y", "{v:.2f} cm"),
    ("z", "{v:.2f} cm"),
    ("r", "{v:.2f} cm"),
    ("theta", "{v:.2f} rad"),
    ("drift_time", "{v} ns"),
    ("alt_s1_interaction_drift_time", "{v} ns"),
    ("alt_s2_interaction_drift_time", "{v} ns"),
)


# Don't be smart with the arguments: since this is a mini-analysis, we
# need to have all the arguments
@straxen.mini_analysis(requires=("event_info",))
def event_display(
    context,
    run_id,
    events,
    to_pe,
    records_matrix=True,
    s2_fuzz=50,
    s1_fuzz=0,
    max_peaks=500,
    xenon1t=False,
    s1_hp_kwargs=None,
    s2_hp_kwargs=None,
    event_time_limit=None,
    plot_all_positions=True,
    display_peak_info=PEAK_DISPLAY_DEFAULT_INFO,
    display_event_info=EVENT_DISPLAY_DEFAULT_INFO,
    simple_layout=False,
):
    """{event_docs} {event_returns}"""
    # NB: the docstring above is a template. Its placeholders are filled in
    # at the bottom of this module, so it must keep exactly these fields.
    if records_matrix not in ("raw", True, False):
        raise ValueError('Choose either "raw", True or False for records_matrix')

    # Only check the storage (and only complain about missing data) when a
    # records matrix was actually requested. With records_matrix=False there
    # is nothing to look up and nothing to warn about.
    if records_matrix:
        required_data = "raw_records" if records_matrix == "raw" else "records"
        if not context.is_stored(run_id, required_data):
            print("(raw)records not stored! Not showing records_matrix")
            records_matrix = False

    # Convert to int (0 or 1) to allow plots to be enlarged for the extra panel
    _rr_resize_int = int(bool(records_matrix))

    # simple_layout=True builds the compact figure (fewer panels), otherwise
    # the full figure is built, with an extra row if a records matrix is shown.
    if simple_layout:
        axes = _event_display_simple_layout()
    else:
        axes = _event_display_full_layout(_rr_resize_int, records_matrix)

    return _event_display(
        context,
        run_id,
        events,
        to_pe,
        axes=axes,
        records_matrix=records_matrix,
        s2_fuzz=s2_fuzz,
        s1_fuzz=s1_fuzz,
        max_peaks=max_peaks,
        xenon1t=xenon1t,
        display_peak_info=display_peak_info,
        display_event_info=display_event_info,
        s1_hp_kwargs=s1_hp_kwargs,
        s2_hp_kwargs=s2_hp_kwargs,
        event_time_limit=event_time_limit,
        plot_all_positions=plot_all_positions,
    )


def _event_display(
    context,
    run_id,
    events,
    to_pe,
    axes=None,
    records_matrix=True,
    s2_fuzz=50,
    s1_fuzz=0,
    max_peaks=500,
    xenon1t=False,
    display_peak_info=PEAK_DISPLAY_DEFAULT_INFO,
    display_event_info=EVENT_DISPLAY_DEFAULT_INFO,
    s1_hp_kwargs=None,
    s2_hp_kwargs=None,
    event_time_limit=None,
    plot_all_positions=True,
):
    """{event_docs}

    :param axes: dict of matplotlib axes to draw on, keyed by panel name (same names as the
        returned axes listed below). Panels whose key is missing or None are left out.
    :raises ValueError: if events does not contain exactly one event, or if axes is None.
    {event_returns}

    """
    if len(events) != 1:
        raise ValueError(f"Found {len(events)} only request one")
    # From here on work with the single event, so that all fields are scalars
    event = events[0]

    if axes is None:
        raise ValueError("No axes provided")
    ax_s1 = axes.get("ax_s1", None)
    ax_s2 = axes.get("ax_s2", None)
    ax_s1_hp_t = axes.get("ax_s1_hp_t", None)
    ax_s1_hp_b = axes.get("ax_s1_hp_b", None)
    ax_s2_hp_t = axes.get("ax_s2_hp_t", None)
    ax_s2_hp_b = axes.get("ax_s2_hp_b", None)
    ax_event_info = axes.get("ax_event_info", None)
    ax_peak_info = axes.get("ax_peak_info", None)
    ax_ev = axes.get("ax_ev", None)
    ax_rec = axes.get("ax_rec", None)

    # titles
    for ax, title in (
        (ax_s1, "Main S1"),
        (ax_s1_hp_t, "S1 top"),
        (ax_s1_hp_b, "S1 bottom"),
        (ax_s2, "Main S2"),
        (ax_s2_hp_t, "S2 top"),
        (ax_s2_hp_b, "S2 bottom"),
        (ax_event_info, "Event info"),
        (ax_peak_info, "Peak info"),
    ):
        if ax is not None:
            ax.set_title(title)

    # Hit pattern options shared by S1 and S2. They are merged into *new*
    # dicts (user-specified keys win) so that dicts passed in by the caller
    # are never modified.
    _common_opt = dict(
        xenon1t=xenon1t,
        pmt_label_color="lightgrey",
        log_scale=True,
        vmin=0.1,
        s=(250 if records_matrix else 220),
        pmt_label_size=7,
        edgecolor="grey",
        dead_pmts=np.argwhere(to_pe == 0),
    )
    s1_hp_kwargs = {**_common_opt, "cmap": "Blues", **(s1_hp_kwargs or {})}
    s2_hp_kwargs = {**_common_opt, "cmap": "Greens", **(s2_hp_kwargs or {})}

    # S1
    if event["s1_area"] != 0:
        if ax_s1 is not None:
            plt.sca(ax_s1)
            context.plot_peaks(
                run_id,
                time_range=(event["s1_time"] - s1_fuzz, event["s1_endtime"] + s1_fuzz),
                single_figure=False,
            )

        # Hit pattern plots; only load the peaklets if there is a panel to fill
        s1_hp_axes = [
            (ax, array)
            for ax, array in ((ax_s1_hp_t, "top"), (ax_s1_hp_b, "bottom"))
            if ax is not None
        ]
        if s1_hp_axes:
            area = context.get_array(
                run_id,
                "peaklets",
                time_range=(event["s1_time"], event["s1_endtime"]),
                keep_columns=("area_per_channel", "time", "dt", "length"),
                progress_bar=False,
            )
            for ax, array in s1_hp_axes:
                plt.sca(ax)
                straxen.plot_on_single_pmt_array(
                    c=np.sum(area["area_per_channel"], axis=0), array_name=array, **s1_hp_kwargs
                )
                # Mark reconstructed position
                plt.scatter(event["x"], event["y"], marker="X", s=100, c="k")

    # S2
    if event["s2_area"] != 0:
        if ax_s2 is not None:
            plt.sca(ax_s2)
            context.plot_peaks(
                run_id,
                time_range=(event["s2_time"] - s2_fuzz, event["s2_endtime"] + s2_fuzz),
                single_figure=False,
            )

        # Hit pattern plots; only load the peaklets if there is a panel to fill
        s2_hp_axes = [
            (ax, array)
            for ax, array in ((ax_s2_hp_t, "top"), (ax_s2_hp_b, "bottom"))
            if ax is not None
        ]
        if s2_hp_axes:
            area = context.get_array(
                run_id,
                "peaklets",
                time_range=(event["s2_time"], event["s2_endtime"]),
                keep_columns=("area_per_channel", "time", "dt", "length"),
                progress_bar=False,
            )
            for ax, array in s2_hp_axes:
                plt.sca(ax)
                straxen.plot_on_single_pmt_array(
                    c=np.sum(area["area_per_channel"], axis=0), array_name=array, **s2_hp_kwargs
                )
                # Mark reconstructed position (corrected)
                plt.scatter(event["x"], event["y"], marker="X", s=100, c="k")
                # The positions of the individual algorithms are only
                # overlaid on the top array
                if array == "top" and not xenon1t and plot_all_positions:
                    _scatter_rec(event)

    # Fill panels with peak/event info
    for ax, labels_and_unit in (
        (ax_event_info, display_event_info),
        (ax_peak_info, display_peak_info),
    ):
        if ax is None:
            continue
        n_labels = len(labels_and_unit)
        for i, (_lab, _unit) in enumerate(labels_and_unit):
            # Spread the rows evenly from the top of the panel downwards
            coord = 0.01, 0.9 - 0.9 * i / n_labels
            ax.text(*coord, _lab[:24], va="top", zorder=-10)
            ax.text(coord[0] + 0.5, coord[1], _unit.format(v=event[_lab]), va="top", zorder=-10)
        # Remove axes and labels from panel (once per panel is enough)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Plot peaks in event
    ev_range = None
    if ax_ev is not None:
        plt.sca(ax_ev)
        if event_time_limit is None:
            time_range = (event["time"], event["endtime"])
        else:
            time_range = event_time_limit

        context.plot_peaks(
            run_id, time_range=time_range, show_largest=max_peaks, single_figure=False
        )
        # Remember the x-range so the records matrix can be aligned with it
        ev_range = plt.xlim()

    if records_matrix and ax_rec is not None:
        plt.sca(ax_rec)
        context.plot_records_matrix(
            run_id,
            raw=records_matrix == "raw",
            time_range=(event["time"], event["endtime"]),
            single_figure=False,
        )
        ax_rec.tick_params(axis="x", rotation=0)
        if not xenon1t:
            # Top vs bottom division
            ax_rec.axhline(straxen.n_top_pmts, c="k")
        if ev_range is not None:
            plt.xlim(*ev_range)

    # Final tweaks: rotate the tick labels of each panel that is present
    if ax_s1 is not None:
        ax_s1.tick_params(axis="x", rotation=45)
    if ax_s2 is not None:
        ax_s2.tick_params(axis="x", rotation=45)
    if ax_ev is not None:
        ax_ev.tick_params(axis="x", rotation=0)

    # Show the event time as <seconds>.<nanoseconds>. Use the scalar of the
    # single event: the string of a length-1 array would include brackets
    # and shift the split by one digit.
    event_time_str = str(event["time"])
    title = (
        f"Run {run_id}. Time "
        f"{event_time_str[:-9]}.{event_time_str[-9:]}\n"
        f"{datetime.fromtimestamp(event['time'] / 1e9, tz=pytz.utc)}"
    )
    plt.suptitle(title, y=0.95)
    # NB: reflects panels order
    return (
        ax_s1,
        ax_s2,
        ax_s1_hp_t,
        ax_s1_hp_b,
        ax_event_info,
        ax_peak_info,
        ax_s2_hp_t,
        ax_s2_hp_b,
        ax_ev,
        ax_rec,
    )


@straxen.mini_analysis(requires=("event_info",))
def event_display_simple(context, run_id, events, **kwargs):
    # Stub mini-analysis: it never plots anything and only points the caller
    # to event_display(.., simple_layout=True).
    hint = "Pass st.event_display(.., simple_layout=True)"
    raise NotImplementedError(hint)


def _event_display_simple_layout() -> dict:
    """Create the figure and axes of the simple event display.

    A 2x3 gridspec is used: the top row holds the main S1, the main S2 and the S2 top hit
    pattern, the bottom row is spanned by the event waveform.

    :return: dict of axes with the keys ax_s1, ax_s2, ax_s2_hp_t and ax_ev

    """
    figure = plt.figure(figsize=(12, 8), facecolor="white")
    layout = plt.GridSpec(2, 3, hspace=0.5)
    return {
        "ax_s1": figure.add_subplot(layout[0, 0]),
        "ax_s2": figure.add_subplot(layout[0, 1]),
        "ax_s2_hp_t": figure.add_subplot(layout[0, 2]),
        "ax_ev": figure.add_subplot(layout[1, :]),
    }


def _event_display_full_layout(_rr_resize_int, records_matrix) -> dict:
    """Create the figure and axes of the full event display.

    :param _rr_resize_int: 1 if an extra row for the (raw)records matrix is needed, else 0. Used
        to enlarge the figure and to add a row to the outer grid.
    :param records_matrix: if truthy, an axis for the (raw)records matrix is created.
    :return: dict of axes with the keys ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b, ax_event_info,
        ax_peak_info, ax_s2_hp_t, ax_s2_hp_b, ax_ev and ax_rec (None without records matrix)

    """
    n_rows = 2 + _rr_resize_int
    fig_height = 21 if _rr_resize_int else 16
    figure = plt.figure(figsize=(25, fig_height), facecolor="white")
    outer = plt.GridSpec(
        n_rows,
        1,
        hspace=0.1 + 0.1 * _rr_resize_int,
        height_ratios=[1.5, 0.5, 0.5][:n_rows],
    )

    # Top block: 2x4 grid with S1, S2 and the S1 hit patterns in its first
    # row; event info, peak info and the S2 hit patterns in its second row.
    top_block = gridspec.GridSpecFromSubplotSpec(
        2, 4, subplot_spec=outer[0], wspace=0.25, hspace=0.4
    )
    # (panel name, cell in the 2x4 grid), listed in the order the axes are
    # added to the figure
    cell_of_panel = (
        ("ax_s1", 0),
        ("ax_s2", 1),
        ("ax_s1_hp_t", 2),
        ("ax_s1_hp_b", 3),
        ("ax_s2_hp_t", 6),
        ("ax_s2_hp_b", 7),
        ("ax_event_info", 4),
        ("ax_peak_info", 5),
    )
    panels = {name: figure.add_subplot(top_block[cell]) for name, cell in cell_of_panel}

    # All peaks in event
    event_block = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[1])
    panels["ax_ev"] = figure.add_subplot(event_block[0])

    # (raw)records matrix (optional)
    panels["ax_rec"] = None
    if records_matrix:
        records_block = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[2])
        panels["ax_rec"] = figure.add_subplot(records_block[0])

    # Hand the panels back in a fixed key order
    key_order = (
        "ax_s1",
        "ax_s2",
        "ax_s1_hp_t",
        "ax_s1_hp_b",
        "ax_event_info",
        "ax_peak_info",
        "ax_s2_hp_t",
        "ax_s2_hp_b",
        "ax_ev",
        "ax_rec",
    )
    return {key: panels[key] for key in key_order}


@export
def plot_single_event(context: strax.Context, run_id, events, event_number=None, **kwargs):
    """Wrapper for event_display.

    :param context: strax.context
    :param run_id: run id
    :param events: dataframe / numpy array of events. Should either be length 1 or the
        event_number argument should be provided
    :param event_number: (optional) int, if provided, only show this event number
    :param kwargs: kwargs for events_display
    :return: see events_display
    :raises ValueError: if the (selected) events do not contain exactly one event

    """
    if event_number is not None:
        events = events[events["event_number"] == event_number]
    if len(events) != 1:
        raise ValueError(
            f"Make sure to provide an event number or a single event. Got {len(events)} events"
        )
    # A DataFrame needs positional row access via .iloc (events[0] would be a
    # column lookup there); numpy arrays are indexed directly.
    event = events.iloc[0] if hasattr(events, "iloc") else events[0]
    return context.event_display(run_id, time_range=(event["time"], event["endtime"]), **kwargs)
def _scatter_rec(
    _event,
    recs=None,
    scatter_kwargs=None,
):
    """Convenient wrapper to show posrec of three algorithms for xenonnt.

    Draws on the current matplotlib axes.

    :param _event: single event, read via the fields s2_x_<rec> and s2_y_<rec>
    :param recs: names of the position reconstruction algorithms to show (at most five),
        defaults to ("mlp", "cnn", "gcn")
    :param scatter_kwargs: optional dict of kwargs passed on to plt.scatter, by default
        s=100 and alpha=0.8 are used
    :raises ValueError: if more than five algorithms are requested

    """
    if recs is None:
        recs = ("mlp", "cnn", "gcn")
    elif len(recs) > 5:
        raise ValueError("I only got five markers/colors")
    # Merge the defaults into a new dict (user values win), so that the dict
    # of the caller is left untouched
    scatter_kwargs = {"s": 100, "alpha": 0.8, **(scatter_kwargs or {})}
    shapes = ("v", "^", ">", "<", "*", "D", "P")
    colors = ("brown", "orange", "lightcoral", "gold", "lime", "crimson")
    for _r, shape, color in zip(recs, shapes, colors):
        x, y = _event[f"s2_x_{_r}"], _event[f"s2_y_{_r}"]
        # Skip algorithms without a valid position
        if np.isnan(x) or np.isnan(y):
            continue
        plt.scatter(
            x,
            y,
            marker=shape,
            c=color,
            label=_r.upper(),
            **scatter_kwargs,
        )
    plt.legend(loc="best", fontsize="x-small", markerscale=0.5)


# Event display docstrings.
# Let's add them to the corresponding functions

event_docs = """
Make a waveform-display of a given event. Requires events, peaks and
peaklets (optionally: records). NB: time selection should return
only one event!

:param context: strax.Context provided by the minianalysis wrapper
:param run_id: run-id of the event
:param events: events, provided by the minianalysis wrapper
:param to_pe: gains, provided by the minianalysis wrapper
:param records_matrix: False (no record matrix), True, or "raw"
    (show raw-record matrix)
:param s2_fuzz: extra time around main S2 [ns]
:param s1_fuzz: extra time around main S1 [ns]
:param max_peaks: max peaks for plotting in the wf plot
:param xenon1t: True: is 1T, False: is nT
:param display_peak_info: tuple, items that will be extracted from
    event and displayed in the peak info panel see above for format
:param display_event_info: tuple, items that will be extracted from
    event and displayed in the event info panel see above for format
:param s1_hp_kwargs: dict, optional kwargs for S1 hitpatterns
:param s2_hp_kwargs: dict, optional kwargs for S2 hitpatterns
:param event_time_limit: overrides x-axis limits of event plot
:param plot_all_positions: if True, plot best-fit positions
    from all posrec algorithms
"""
event_returns = """
:return: axes used for plotting:

    ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b,
    ax_event_info, ax_peak_info, ax_s2_hp_t, ax_s2_hp_b,
    ax_ev,
    ax_rec

    Where those panels (axes) are:

     - ax_s1, main S1 peak
     - ax_s2, main S2 peak
     - ax_s1_hp_t, S1 top hit pattern
     - ax_s1_hp_b, S1 bottom hit pattern
     - ax_s2_hp_t, S2 top hit pattern
     - ax_s2_hp_b, S2 bottom hit pattern
     - ax_event_info, text info on the event
     - ax_peak_info, text info on the main S1 and S2
     - ax_ev, waveform of the entire event
     - ax_rec, (raw)record matrix (if any otherwise None)
"""

# Add the same docstring to each of these functions. Functions without a
# docstring are skipped; the others hold the placeholders filled in here.
for event_function in (event_display, event_display_simple, _event_display):
    doc = event_function.__doc__
    if doc is not None:
        event_function.__doc__ = doc.format(event_docs=event_docs, event_returns=event_returns)