Source code for stytra.gui.fishplots

from PyQt5.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton
import pyqtgraph as pg
import numpy as np
from stytra.collectors import DataFrameAccumulator
from stytra.collectors import QueueDataAccumulator
from stytra.tracking.online_bouts import find_bouts_online, BoutState
from stytra.utilities import reduce_to_pi

from collections import deque
from lightparam import Param, Parametrized
from lightparam.gui import ParameterGui

from scipy.ndimage.filters import gaussian_filter1d

class StreamingPositionPlot(pg.GraphicsWindow):
    """Window drawing the most recent x/y trajectory held in an accumulator.

    Parameters
    ----------
    data_accumulator : DataFrameAccumulator
        Source of the samples; its ``columns`` must contain ``"x"`` and
        ``"y"``.
    n_points : int
        Number of most recent samples requested on every refresh.
    """

    def __init__(self, *args, data_accumulator, n_points=500, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(data_accumulator, DataFrameAccumulator)

        self.title = "position"

        # One plot with a locked aspect ratio, holding a single curve.
        self.positionPlot = self.addPlot()
        self.positionPlot.setAspectLocked(True)
        self.curve = self.positionPlot.plot()

        self.n_points = n_points
        self.data_accumulator = data_accumulator

        # Column positions of the two coordinates inside the accumulated rows.
        column_names = self.data_accumulator.columns
        self.ind_x = column_names.index("x")
        self.ind_y = column_names.index("y")

    def update(self):
        """Redraw the curve from the last ``n_points`` accumulated samples.

        The squared step length between consecutive samples is clipped to
        [0, 30], scaled to [0, 1] and padded with a trailing zero so that it
        has one entry per sample; it is then turned into a per-sample RGBA
        array handed to the curve as ``color``.
        """
        try:
            recent = self.data_accumulator.get_last_n(self.n_points)
            xs = recent[:, self.ind_x]
            ys = recent[:, self.ind_y]

            squared_step = np.diff(xs) ** 2 + np.diff(ys) ** 2
            velocity = np.r_[np.clip(squared_step, 0, 30) / 30, [0]]

            rgba = np.stack(
                [
                    0.5 + 0.5 * velocity,
                    0.2 + 0.8 * velocity,
                    velocity,
                    np.ones_like(velocity),
                ],
                1,
            )
            self.curve.setData(x=xs, y=ys, color=rgba)
        except (IndexError, TypeError):
            # Best effort: skip this refresh (presumably the accumulator has
            # no usable data yet — confirm).
            pass
class TailStreamPlot(QWidget):
    """Heatmap of tail curvature over time.

    Parameters
    ----------
    acc
        Accumulator providing ``get_last_n``.
    n_points : int
        Number of most recent samples shown.
    """

    def __init__(self, acc, n_points=300):
        super().__init__()
        self.title = "Tail curvature"
        self.acc = acc
        self.headers = None
        self.n_points = n_points

        box = QVBoxLayout()
        self.setLayout(box)
        box.setContentsMargins(0, 0, 0, 0)

        # Image item inside a view box inside a graphics layout widget.
        self.display_widget = pg.GraphicsLayoutWidget()
        self.vb_display = pg.ViewBox()
        self.display_widget.addItem(self.vb_display)
        self.image_item = pg.ImageItem()
        self.vb_display.addItem(self.image_item)

        # Fixed display range; update() never rescales it.
        self.image_item.setLevels((-0.6, 0.6))

        # Five evenly spaced colour stops, with black in the middle.
        stop_positions = np.linspace(0, 1, 5)
        stop_colors = np.array(
            [
                [0.42107294, 0.80737975, 0.49219722],
                [0.23166242, 0.39962101, 0.32100403],
                [0.0, 0.0, 0.0],
                [0.46170494, 0.30327584, 0.38740225],
                [0.91677407, 0.58427975, 0.92293321],
            ]
        )
        color_map = pg.ColorMap(stop_positions, stop_colors)
        self.image_item.setLookupTable(color_map.getLookupTable(alpha=False))

        box.addWidget(self.display_widget)

    def update(self):
        """Refresh the heatmap from the last ``n_points`` samples.

        Does nothing while the widget is hidden or when the accumulator
        returns no rows. Columns from index 2 onward are differenced along
        axis 1 and the result is transposed before display.
        """
        if not self.isVisible():
            return

        recent = self.acc.get_last_n(self.n_points)
        if len(recent) == 0:
            return

        # presumably the first two columns are not tail-segment angles and
        # the rest are — verify against the accumulator's layout
        curvature = np.diff(recent[:, 2:], axis=1).T
        self.image_item.setImage(image=curvature, autoLevels=False)
def rot_mat(theta):
    """Return the 2x2 rotation matrix for the angle ``theta`` (radians).

    The layout is ``[[cos, -sin], [sin, cos]]``.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])
def angle_mean(angles, axis=1):
    """Circular mean of a collection of angles, in radians.

    The sines and cosines of all elements are summed and the angle of the
    resulting vector is returned.

    NOTE(review): ``axis`` is accepted but never used; every element of
    ``angles`` is reduced into a single value regardless of its shape.
    """
    sin_total = np.sum(np.sin(angles))
    cos_total = np.sum(np.cos(angles))
    return np.arctan2(sin_total, cos_total)
def normalise_bout(coord):
    """Reset a bout to be facing upward and start from 0,0.

    Parameters
    ----------
    coord
        2D array indexed as ``[sample, column]``; columns 0 and 1 are
        translated and rotated as positions, column 2 is treated as an
        angle. The array is modified in place.

    Returns
    -------
    The same ``coord`` array, after normalisation.
    """
    # Reference heading: circular mean of the angle column over the first
    # two samples.
    initial_heading = angle_mean(coord[:2, 2])

    # Move the first sample to the origin, then rotate the positions.
    rotation = rot_mat(initial_heading + np.pi)
    coord[:, :2] = (coord[:, :2] - coord[:1, :2]) @ rotation

    # Express the angles relative to the reference heading.
    coord[:, 2] -= initial_heading
    coord[:, 2] = reduce_to_pi(coord[:, 2])
    return coord
class BoutPlot(QWidget):
    """Plots the last few bouts in normalized coordinates, with fish facing
    to the right.

    Parameters
    ----------
    acc : QueueDataAccumulator
        Accumulator whose ``stored_data`` and ``header_dict`` are read on
        every refresh.
    i_fish : int
        Index of the fish; selects the ``f<i_fish>_x``, ``f<i_fish>_y`` and
        ``f<i_fish>_theta`` columns.
    n_bouts : int
        Number of bout curves kept on screen.
    n_save_max : int
        Maximum number of samples pulled from the accumulator per refresh
        and carried over to the next one.
    """

    def __init__(self, acc: QueueDataAccumulator, i_fish=0, n_bouts=10, n_save_max=300):
        super().__init__()
        self.title = "Bout shape"
        self.acc = acc
        self.bouts = deque()
        self.i_fish = i_fish
        # Number of accumulator rows already consumed by update().
        self.processed_index = 0

        # Forwarded as keyword arguments to find_bouts_online.
        self.detection_params = Parametrized(
            params=dict(
                threshold=Param(0.2, (0.01, 5.0)),
                n_without_crossing=Param(5, (0, 10)),
                pad_before=Param(5, (0, 20)),
                pad_after=Param(5, (0, 20)),
                min_bout_len=Param(1, (1, 30)),
            )
        )
        self.n_bouts = n_bouts
        self.old_coords = None
        # Index of the curve that will receive the next finished bout.
        self.i_curve = 0
        self.n_save_max = n_save_max

        self.setLayout(QVBoxLayout())
        self.layout().setContentsMargins(0, 0, 0, 0)

        self.btn_editparam = QPushButton("Detection parameters")
        self.btn_editparam.clicked.connect(self.edit_params)
        self.layout().addWidget(self.btn_editparam)
        self.wnd_params = None

        self.vmax = 0
        self.lbl_vmax = QLabel()
        self.layout().addWidget(self.lbl_vmax)

        self.display_widget = pg.GraphicsLayoutWidget()
        self.layout().addWidget(self.display_widget)
        self.vb_display = pg.ViewBox()
        self.vb_display.setAspectLocked(True, 1)
        self.vb_display.setRange(xRange=[-1, 5], disableAutoRange=True)
        self.vb_display.invertY(True)
        self.display_widget.addItem(self.vb_display)

        self.bout_curves = [
            pg.PlotCurveItem(connect="finite") for _ in range(self.n_bouts)
        ]

        # Per-curve grey level: set to 255 when a bout is drawn, then
        # multiplied by decay_constant on every refresh.
        self.colors = np.zeros(self.n_bouts)
        self.decay_constant = 0.99

        self.bout_coords = None
        self.bout_state = BoutState(0, 0.0, 0, 0, 0)

        for c in self.bout_curves:
            self.vb_display.addItem(c)

    def edit_params(self):
        """Open a window for editing the bout detection parameters."""
        # The reference is kept on self so the window outlives this call.
        self.wnd_params = ParameterGui(self.detection_params)
        # A widget without a parent stays hidden until show() is called.
        self.wnd_params.show()

    def update(self):
        """Consume newly accumulated samples and redraw the bout curves.

        Does nothing while the widget is hidden or when fewer than two new
        samples are available. Otherwise the new samples are passed to
        ``find_bouts_online``; a finished bout longer than two samples is
        normalised and drawn on the next curve slot, and every curve is
        faded a little.
        """
        if not self.isVisible():
            return

        current_index = len(self.acc.stored_data)
        if current_index == 0 or current_index < self.processed_index + 2:
            return

        # Pull the new data from the accumulator
        first_new = max(self.processed_index, current_index - self.n_save_max)
        new_coords = np.array(self.acc.stored_data[first_new:current_index])
        self.processed_index = current_index

        ix, iy, ith = (
            self.acc.header_dict["f{:d}_{}".format(self.i_fish, var)]
            for var in ["x", "y", "theta"]
        )
        new_coords = new_coords[:, [ix, iy, ith]]

        # if in the previous refresh we ended up inside a bout, there are still
        # coordinates left to process
        if self.old_coords is not None:
            pre_start = len(self.old_coords)
            new_coords = np.concatenate([self.old_coords, new_coords], 0)
        else:
            pre_start = 0

        # Squared displacement between consecutive samples.
        vel = np.sum(np.diff(new_coords[:, :2], axis=0) ** 2, axis=1)

        self.vmax = np.nanmax(vel)
        self.lbl_vmax.setText("max velocity sq {:.1f}".format(self.vmax))

        if self.bout_coords is None:
            self.bout_coords = [new_coords[0, :]]

        new_bout = None
        if self.detection_params.threshold > 0:
            self.bout_coords, bout_finished, self.bout_state = find_bouts_online(
                vel,
                new_coords,
                self.bout_state,
                bout_coords=self.bout_coords,
                shift=pre_start,
                **self.detection_params.params.values,
            )
            if bout_finished:
                new_bout = self.bout_coords
                self.bout_coords = [new_coords[0, :]]

        # NOTE(review): the source listing lost its indentation; this
        # assignment is placed at method level so that the carried-over
        # samples are refreshed on every call — confirm against upstream.
        self.old_coords = new_coords[-self.n_save_max :, :]

        self.colors *= self.decay_constant

        if new_bout is not None and len(new_bout) > 2:
            # The first entry is the placeholder sample stored when the bout
            # buffer was (re)initialised, so it is skipped.
            nb = normalise_bout(np.array(new_bout[1:]))
            self.bout_curves[self.i_curve].setData(x=nb[:, 0], y=nb[:, 1])
            self.colors[self.i_curve] = 255
            self.i_curve = (self.i_curve + 1) % self.n_bouts

        # Clear curves that have faded below grey level 10; repaint the rest.
        for curve, color in zip(self.bout_curves, self.colors):
            col = int(color)
            if col < 10:
                curve.setData(x=[], y=[])
            else:
                curve.setPen((col, col, col))