Source code for stytra.gui.multiscope

import datetime
from collections import namedtuple
from stytra.collectors.accumulators import DataFrameAccumulator

import colorspacious
import numpy as np
import pyqtgraph as pg
from PyQt5.QtGui import QFont, QPalette, QColor
from PyQt5.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QPushButton,
    QLabel,
    QDoubleSpinBox,
    QSpacerItem,
    QGroupBox,
    QCheckBox,
    QSizePolicy,
)

# Record grouping the five graphics items that belong to one plotted stream:
# the curve itself plus its name, minimum, maximum and current-value labels.
PlotTuple = namedtuple(
    "PlotTuple",
    "curve curve_label min_label max_label value_label",
)


class MultiStreamPlot(QWidget):
    """Widget that plots live data gathered by accumulator objects.

    Every selected column of every registered accumulator is drawn as its own
    curve, rescaled to a band of unit height and stacked vertically in a
    single pyqtgraph plot. New streams are registered with
    :meth:`add_stream`; :meth:`update` is meant to be called by an external
    timer to redraw the curves.

    Parameters
    ----------
    time_past : float
        Length of the plotted time window (the zoom spin box labels it in
        seconds). (Default value = 5)
    bounds_update : float
        Weight given to freshly computed bounds when they are blended with
        the previous ones in :meth:`update_bounds`. (Default value = 0.1)
    compact : bool
        If True the control bar (variable selection, freeze button, zoom
        spin box) is not created. (Default value = False)
    n_points_max : int
        Number of samples above which the data is subsampled before being
        plotted. (Default value = 500)
    accumulators : list, optional
        Initial content of ``self.accumulators``.
        NOTE(review): no columns or curves are registered for these entries
        here; streams only become visible through :meth:`add_stream`.
    precision : int, optional
        Number of decimals used in the numeric labels; 3 when not given.
    experiment : object, optional
        Object from which ``t0``, ``camera_state``, ``tracking_method_name``
        and ``window_main`` are read when available.
    *args, **kwargs
        Forwarded to ``QWidget``.

    """

    def __init__(
        self,
        time_past=5,
        bounds_update=0.1,
        compact=False,
        n_points_max=500,
        accumulators=None,
        precision=None,
        experiment=None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.experiment = experiment
        self.time_past = time_past
        self.compact = compact
        self.n_points_max = n_points_max

        self.setLayout(QVBoxLayout())
        self.layout().setContentsMargins(0, 0, 0, 0)

        # `precision or 3` would silently turn a requested precision of 0
        # into 3, so only None selects the default.
        self.precision = 3 if precision is None else precision
        self.penwidth = 2

        if not compact:
            self.control_layout = QHBoxLayout()
            self.control_layout.setContentsMargins(0, 0, 0, 0)

            self.btn_select = QPushButton("Choose variables")
            self.btn_select.clicked.connect(self.show_select)
            self.control_layout.addWidget(self.btn_select)
            self.wnd_config = None

            self.btn_freeze = QPushButton()
            self.btn_freeze.setMinimumSize(80, 16)
            self.btn_freeze.clicked.connect(self.toggle_freeze)
            self.control_layout.addWidget(self.btn_freeze)

            # The extra-plot button only exists for the "tail" and "fish"
            # tracking methods; an experiment without the attribute (or no
            # experiment at all) simply gets no button.
            try:
                tm = self.experiment.tracking_method_name
                if tm == "tail" or tm == "fish":
                    self.btn_extra = QPushButton(
                        "Show tail curvature" if tm == "tail" else "Show last bouts"
                    )
                    self.btn_extra.clicked.connect(self.show_extra_plot)
                    self.control_layout.addWidget(self.btn_extra)
            except AttributeError:
                pass

            self.lbl_zoom = QLabel("Plot past ")
            self.spn_zoom = QDoubleSpinBox()
            self.spn_zoom.setValue(time_past)
            self.spn_zoom.setSuffix("s")
            self.spn_zoom.setMinimum(0.1)
            self.spn_zoom.setMaximum(30)
            self.spn_zoom.valueChanged.connect(self.update_zoom)

            self.control_layout.addItem(
                QSpacerItem(0, 0, QSizePolicy.Expanding, QSizePolicy.Minimum)
            )
            self.control_layout.addWidget(self.lbl_zoom)
            self.control_layout.addWidget(self.spn_zoom)

            self.layout().addLayout(self.control_layout)

        self.plotContainer = pg.PlotWidget()
        self.plotContainer.showAxis("left", False)
        self.plotContainer.plotItem.hideButtons()

        # Two draggable vertical lines delimiting the replay interval; they
        # are only shown while the plot is frozen (see toggle_freeze).
        self.replay_left = pg.InfiniteLine(
            -1, pen=(220, 220, 220), movable=True, hoverPen=(230, 30, 0)
        )
        self.replay_right = pg.InfiniteLine(
            -1, pen=(220, 220, 220), movable=True, hoverPen=(230, 30, 0)
        )
        for rep_line in [self.replay_left, self.replay_right]:
            rep_line.sigDragged.connect(self.update_replay_limits)

        self.layout().addWidget(self.plotContainer)

        self.active_plots = []
        self.accumulators = accumulators or []
        self.selected_columns = []
        self.stream_items = []
        self.stream_scales = []
        self.bounds = []
        self.bounds_update = bounds_update
        self.colors = []

        # Start as "frozen" so that the toggle_freeze() call below leaves the
        # widget in live mode with the button and view correctly initialised.
        self.frozen = True
        # Never changed inside this class; while falsy the min/max labels
        # are left blank by _set_labels.
        self.bounds_visible = None

        # trick to set color on update
        self.color_set = False

        self.toggle_freeze()
        self.update_zoom(time_past)
        self.update_buflen(time_past)

    @staticmethod
    def get_colors(n_colors=1, lightness=50, saturation=50, shift=0):
        """Get colors on the LCh ring

        Hues are spaced evenly over [0, 360), offset by ``shift``, converted
        from "CIELCh" to "sRGB1" with colorspacious, clipped to [0, 1] and
        scaled by 255.

        Parameters
        ----------
        n_colors : int
            Number of colors to generate. (Default value = 1)
        lightness : float
            Lightness used for every color. (Default value = 50)
        saturation : float
            Chroma used for every color. (Default value = 50)
        shift : float
            Offset added to every hue. (Default value = 0)

        Returns
        -------
        numpy.ndarray
            Array with one row per color and three columns (R, G, B), with
            values between 0 and 255.

        """
        hues = np.linspace(0, 360, n_colors + 1)[:-1] + shift
        lch = np.stack(
            [
                np.ones_like(hues) * lightness,
                np.ones_like(hues) * saturation,
                hues,
            ],
            1,
        )
        rgb = colorspacious.cspace_convert(lch, "CIELCh", "sRGB1")
        return np.clip(rgb, 0, 1) * 255

    def refresh_streams(self):
        """Placeholder hook; it currently performs no work.

        Accumulators and their selected columns are kept in two parallel
        lists (``self.accumulators`` and ``self.selected_columns``), so the
        entries of ``self.accumulators`` cannot be unpacked into pairs.
        """
        return None

    def add_stream(self, accumulator: DataFrameAccumulator, header_items=None):
        """Adds a data collector stream to the plot:

        Parameters
        ----------
        accumulator : DataFrameAccumulator
            Accumulator whose columns are to be plotted.
        header_items : list of str, optional
            Names of the columns to plot. When None,
            ``accumulator.plot_columns`` is used if it is not None, otherwise
            all columns but the first one.

        Returns
        -------
        None
            If a ValueError is raised while the columns are selected or the
            colors are computed, the stream is not registered.

        """
        try:
            if header_items is None:
                if accumulator.plot_columns is not None:
                    header_items = accumulator.plot_columns
                else:
                    # first column is always t
                    header_items = accumulator.columns[1:]
            # Colors are recomputed for all curves so that they stay evenly
            # spread over the hue ring.
            self.colors = self.get_colors(len(self.stream_items) + len(header_items))
            self.accumulators.append(accumulator)
            self.selected_columns.append(header_items)
        except ValueError:
            return
        self.bounds.append(None)

        first_new = len(self.stream_items)
        for i_curve, header_item in enumerate(header_items, start=first_new):
            # Each curve occupies the band [i_curve, i_curve + 1] of the plot.
            c = pg.PlotCurveItem(
                x=np.array([0]), y=np.array([i_curve]), connect="finite"
            )
            curve_label = pg.TextItem(header_item, anchor=(0, 1))
            curve_label.setPos(-self.time_past * 0.9, i_curve)

            value_label = pg.TextItem("", anchor=(0, 0.5))
            font_bold = QFont("Sans Serif", 8)
            font_bold.setBold(True)
            value_label.setFont(font_bold)
            value_label.setPos(0, i_curve + 0.5)

            max_label = pg.TextItem("", anchor=(0, 0))
            max_label.setPos(0, i_curve + 1)

            min_label = pg.TextItem("", anchor=(0, 1))
            min_label.setPos(0, i_curve)

            self.stream_items.append(
                PlotTuple(c, curve_label, min_label, max_label, value_label)
            )

        # Only the items created above are added: items of earlier streams
        # are already in the plot and must not be registered a second time.
        for sitems in self.stream_items[first_new:]:
            for itm in sitems:
                self.plotContainer.addItem(itm)

        # Every item, old and new, is recolored with the new palette.
        for sitems, color in zip(self.stream_items, self.colors):
            for itm in sitems:
                if isinstance(itm, pg.PlotCurveItem):
                    itm.setPen(color, width=self.penwidth)
                else:
                    itm.setColor(color)

        self.plotContainer.setYRange(-0.1, len(self.stream_items) + 0.1)

    def remove_streams(self):
        """Remove all curves and labels from the plot and forget all streams."""
        for itmset in self.stream_items:
            for itm in itmset:
                self.plotContainer.removeItem(itm)
        # Fresh lists are bound on purpose: objects that kept a reference to
        # the previous lists are not affected.
        self.stream_items = []
        self.selected_columns = []
        self.accumulators = []
        self.bounds = []

    def _set_labels(self, labels, values=None, precision=3):
        """Write the minimum, maximum and current value labels of one stream.

        Parameters
        ----------
        labels : PlotTuple
            Items of the stream; its ``min_label``, ``max_label`` and
            ``value_label`` are updated (entries that are None are skipped).
        values : sequence of 3 floats, optional
            (minimum, maximum, current value). When None, placeholder texts
            are shown instead.
        precision : int
            Number of decimals of the formatted values. (Default value = 3)

        """
        if values is None:
            txts = ["-", "-", "NaN"]
        else:
            fmt = "{:7.{prec}f}"
            txts = [fmt.format(x, prec=precision) for x in values]

        # Bounds are hidden by blanking their text.
        if not self.bounds_visible:
            txts[0] = ""
            txts[1] = ""

        for lbl, txt in zip(
            [labels.min_label, labels.max_label, labels.value_label], txts
        ):
            if lbl is not None:
                lbl.setText(txt)

    def update_bounds(self, i_acc, new_bounds):
        """Blend newly computed bounds into the stored ones.

        Parameters
        ----------
        i_acc : int
            Index of the accumulator in ``self.bounds``.
        new_bounds : numpy.ndarray
            Array with one (lower, upper) row per selected column.

        """
        if self.bounds[i_acc] is None:
            self.bounds[i_acc] = new_bounds
        else:
            # Exponential smoothing, to avoid jumps in the scaling.
            self.bounds[i_acc] = (
                self.bounds_update * new_bounds
                + (1 - self.bounds_update) * self.bounds[i_acc]
            )

    def update(self):
        """Function called by external timer to update the plot"""
        # The background is set on the first update rather than on creation.
        if not self.color_set:
            self.plotContainer.setBackground(self.palette().color(QPalette.Button))
            self.color_set = True

        if self.frozen:
            return None

        # Nothing is redrawn while the camera is paused; experiments without
        # a camera state (or no experiment) are treated as not paused.
        try:
            if self.experiment.camera_state.paused:
                return None
        except AttributeError:
            pass

        current_time = datetime.datetime.now()

        # Running index over all curves, across accumulators.
        i_stream = 0
        for i_acc, (acc, sel_cols) in enumerate(
            zip(self.accumulators, self.selected_columns)
        ):
            # difference from data accumulator time and now in seconds:
            delta_t = (self.experiment.t0 - current_time).total_seconds()

            data_frame = acc.get_last_t(self.time_past)

            # if this accumulator does not have enough data to plot, skip it
            if data_frame is None or data_frame.shape[0] <= 1:
                for _ in sel_cols:
                    self._set_labels(self.stream_items[i_stream])
                    self.stream_items[i_stream].curve.setData(x=[], y=[])
                    i_stream += 1
                continue

            # downsampling if there are too many points
            if len(data_frame) > self.n_points_max:
                data_frame = data_frame[:: len(data_frame) // self.n_points_max]

            # presumably `t` is in seconds from experiment.t0, which makes
            # this the time relative to now - TODO confirm
            time_array = delta_t + data_frame.t.values

            # loop to handle nan values in a single column
            new_bounds = np.zeros((len(sel_cols), 2))
            for i_col, col in enumerate(sel_cols):
                # Exclude nans from calculation of percentile boundaries:
                d = data_frame[col].values
                if d.dtype != np.float64:
                    # bounds of columns that are not float64 stay at (0, 0)
                    continue
                b = ~np.isnan(d)
                if np.any(b):
                    non_nan_data = data_frame[col][b]
                    new_bounds[i_col, :] = np.percentile(
                        non_nan_data, (0.5, 99.5), 0
                    )
                    # if the bounds are the same, set arbitrary ones
                    if new_bounds[i_col, 0] == new_bounds[i_col, 1]:
                        new_bounds[i_col, 1] += 1

            self.update_bounds(i_acc, new_bounds)

            for col, (lb, ub) in zip(sel_cols, self.bounds[i_acc]):
                scale = ub - lb
                if scale < 0.00001:
                    # a degenerate range cannot be rescaled: show no curve
                    self.stream_items[i_stream].curve.setData(x=[], y=[])
                else:
                    # rescale to [0, 1] and shift to the band of this curve
                    self.stream_items[i_stream].curve.setData(
                        x=time_array,
                        y=i_stream + ((data_frame[col].values - lb) / scale),
                    )
                self._set_labels(
                    self.stream_items[i_stream],
                    values=(lb, ub, data_frame[col].values[-1]),
                    precision=self.precision,
                )
                i_stream += 1

    def show_extra_plot(self):
        """Make the "dock_extra" dock of the experiment main window visible."""
        self.experiment.window_main.docks["dock_extra"].setVisible(True)

    def _qcolorstring(self, color):
        """Return a palette color as a style sheet "rgb(r,g,b)" string.

        Parameters
        ----------
        color : QPalette.ColorRole
            Role of the color to look up in the palette of this widget.

        Returns
        -------
        str
            The color formatted as ``rgb(r,g,b)``.

        """
        qcolor = self.palette().color(color)
        return "rgb({},{},{})".format(qcolor.red(), qcolor.green(), qcolor.blue())

    def toggle_freeze(self):
        """Switch between the live plot and the frozen, explorable plot."""
        self.frozen = not self.frozen
        if self.frozen:
            if not self.compact:
                self.btn_freeze.setText("Live plot")
                self.btn_freeze.setStyleSheet("background-color:rgb(207, 132, 5);")
            self.plotContainer.setBackground(QColor(69, 78, 86))
            self.plotContainer.plotItem.vb.setMouseEnabled(x=True, y=True)
            for rep_line in [self.replay_left, self.replay_right]:
                self.plotContainer.addItem(rep_line)
        else:
            if not self.compact:
                self.btn_freeze.setText("Freeze plot")
            for rep_line in [self.replay_left, self.replay_right]:
                self.plotContainer.removeItem(rep_line)
            if self.color_set:
                self.plotContainer.setBackground(self.palette().color(QPalette.Button))
                # the freeze button only exists when the control bar does
                if not self.compact:
                    self.btn_freeze.setStyleSheet(
                        "background-color:" + self._qcolorstring(QPalette.Mid)
                    )
            self.plotContainer.plotItem.vb.setMouseEnabled(x=False, y=False)
            # back to the default view, discarding any pan/zoom made while
            # the plot was frozen
            self.plotContainer.setXRange(-self.time_past * 0.9, self.time_past * 0.05)
            self.plotContainer.setYRange(-0.1, len(self.stream_items) + 0.1)

    def update_buflen(self, time_past):
        """Set the ring buffer length of the experiment camera state.

        Nothing happens when there is no experiment, or when setting the
        attribute raises IndexError or AttributeError.

        Parameters
        ----------
        time_past : float
            Value assigned to ``camera_state.ring_buffer_length``.

        """
        if self.experiment is not None:
            try:
                self.experiment.camera_state.ring_buffer_length = time_past
            except (IndexError, AttributeError):
                pass

    def update_zoom(self, time_past=1):
        """Change the length of the plotted time window.

        Parameters
        ----------
        time_past : float
            New length of the time window. (Default value = 1)

        """
        # we use the current zoom level and the framerate to determine the
        # rolling buffer length
        self.update_buflen(time_past)
        self.time_past = time_past
        self.plotContainer.setXRange(-self.time_past * 0.9, self.time_past * 0.05)
        self.plotContainer.plotItem.vb.setRange(
            xRange=(-self.time_past * 0.9, self.time_past * 0.05)
        )
        # shift the labels
        for (i_curve, items) in enumerate(self.stream_items):
            items.curve_label.setPos(-self.time_past * 0.9, i_curve)

    def update_replay_limits(self):
        """Store the positions of the replay lines in the camera state.

        The limits are written as (smaller, larger), whichever line is on
        the left. Nothing happens without an experiment or when an
        AttributeError is raised.
        """
        if self.experiment is not None:
            try:
                left_lim = self.replay_left.getXPos()
                right_lim = self.replay_right.getXPos()
                self.experiment.camera_state.replay_limits = (
                    min(left_lim, right_lim),
                    max(left_lim, right_lim),
                )
            except AttributeError:
                pass

    def show_select(self):
        """Open the window for choosing the plotted variables."""
        # The window is kept as an attribute so that a reference to it exists
        # for as long as this widget does.
        self.wnd_config = StreamPlotConfig(self)
        self.wnd_config.show()
class StreamPlotConfig(QWidget):
    """Window with one check box per plottable column of each accumulator.

    The columns are grouped in one box per accumulator. Changing the state of
    any check box rebuilds the streams of the plot from the current selection.

    Parameters
    ----------
    sp : MultiStreamPlot
        The plot whose streams are configured.

    """

    def __init__(self, sp: MultiStreamPlot):
        super().__init__()
        self.sp = sp
        self.main_layout = QVBoxLayout()
        self.setLayout(self.main_layout)

        # Reference to the list owned by the plot at this moment.
        self.accs = sp.accumulators
        self.checkboxes = []

        for accumulator, shown_columns in zip(sp.accumulators, sp.selected_columns):
            group = QGroupBox(accumulator.name)
            group.setLayout(QVBoxLayout())
            group_boxes = []
            # the first column is skipped, as in MultiStreamPlot.add_stream
            for column in accumulator.columns[1:]:
                box = QCheckBox(column)
                # the state is set before the signal is connected, so that
                # building the window does not trigger a refresh
                box.setChecked(column in shown_columns)
                box.stateChanged.connect(self.refresh_plots)
                group_boxes.append(box)
                group.layout().addWidget(box)
            self.checkboxes.append(group_boxes)
            self.main_layout.addWidget(group)

    def refresh_plots(self):
        """Rebuild all streams of the plot from the ticked check boxes."""
        self.sp.remove_streams()
        for group_boxes, accumulator in zip(self.checkboxes, self.accs):
            ticked = [
                column
                for column, box in zip(accumulator.columns[1:], group_boxes)
                if box.isChecked()
            ]
            self.sp.add_stream(accumulator, ticked)
class FrameratePlot(MultiStreamPlot):
    """Stream plot whose bounds snap to multiples of a fixed step.

    Parameters
    ----------
    round_bounds : float
        Step to whose multiples the bounds are rounded. (Default value = 0.1)
    framerate_limits : dict, optional
        Maps accumulator names to a lower limit; see :meth:`update`.
    *args, **kwargs
        Forwarded to ``MultiStreamPlot``.

    """

    def __init__(self, *args, round_bounds=0.1, framerate_limits=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.framerate_limits = framerate_limits if framerate_limits else dict()
        self.round_bounds = round_bounds

    def update(self):
        """Update the plot, then report accumulators below their limit.

        For every accumulator that has a limit in ``framerate_limits`` and
        some stored data, a message is printed when the first element of its
        last stored entry is smaller than the limit.
        """
        super().update()
        for acc in self.accumulators:
            lower_limit = self.framerate_limits.get(acc.name, None)
            if lower_limit is None:
                continue
            if not len(acc.stored_data) > 0:
                continue
            if acc.stored_data[-1][0] < lower_limit:
                print("BAD ", acc.name)

    def _round_bounds(self, bounds):
        """Round bounds outwards to multiples of ``self.round_bounds``.

        Parameters
        ----------
        bounds : numpy.ndarray
            Array with (lower, upper) rows.

        Returns
        -------
        numpy.ndarray
            Lower bounds rounded down and upper bounds rounded up; cast to
            int32 when the rounding step is at least 1.

        """
        step = self.round_bounds
        lower = np.floor(bounds[:, 0] / step) * step
        upper = np.ceil(bounds[:, 1] / step) * step
        rounded = np.stack([lower, upper], 1)
        if step >= 1:
            return rounded.astype(np.int32)
        return rounded

    def _update_round_bounds(self, old_bounds, new_bounds, tolerance=0.1):
        """Replace the rows of the old bounds that drifted beyond tolerance.

        A row is replaced by the rounded new bounds when either of its two
        values differs from the new one by more than ``tolerance`` times the
        absolute old value.

        Parameters
        ----------
        old_bounds : numpy.ndarray
            Stored bounds; modified in place.
        new_bounds : numpy.ndarray
            Newly computed bounds, same shape as ``old_bounds``.
        tolerance : float
            Relative deviation that triggers the replacement.
            (Default value = 0.1)

        Returns
        -------
        numpy.ndarray
            ``old_bounds``, after the update.

        """
        deviation = np.abs(old_bounds - new_bounds)
        allowed = tolerance * np.abs(old_bounds)
        outdated = np.any(deviation > allowed, 1)
        old_bounds[outdated, :] = self._round_bounds(new_bounds[outdated, :])
        return old_bounds

    def update_bounds(self, i_acc, new_bounds):
        """Store rounded bounds, replacing old ones only when they drifted.

        Parameters
        ----------
        i_acc : int
            Index of the accumulator in ``self.bounds``.
        new_bounds : numpy.ndarray
            Array with one (lower, upper) row per selected column.

        """
        stored = self.bounds[i_acc]
        if stored is None:
            self.bounds[i_acc] = self._round_bounds(new_bounds)
        else:
            self.bounds[i_acc] = self._update_round_bounds(stored, new_bounds)