import numpy as np
import datetime
from stytra.collectors import QueueDataAccumulator
from stytra.utilities import reduce_to_pi
from collections import namedtuple
class Estimator:
    """Base class for estimators.

    An estimator derives the quantities needed to control the stimulus
    (e.g. animal position or speed) from the output stream of the tracking
    pipeline (e.g. positions in pixels, tail angles).

    Parameters
    ----------
    acc_tracking : QueueDataAccumulator
        Accumulator holding the tracking output the estimator reads from.
    experiment :
        Owning experiment; its ``estimator_log`` becomes this estimator's log.
    """

    def __init__(self, acc_tracking: QueueDataAccumulator, experiment):
        self.acc_tracking = acc_tracking
        self.exp = experiment
        self.log = self.exp.estimator_log

    def reset(self):
        """Reset the estimator log."""
        self.log.reset()
class VigorMotionEstimator(Estimator):
    """Estimate the velocity of an embedded animal from tail vigor.

    A very common way of estimating the velocity of an embedded animal is
    vigor, computed as the standard deviation of the cumulative tail angle
    (``tail_sum``) in a short time window - generally 50 ms. The vigor is
    multiplied by ``base_gain`` to obtain a velocity.

    Parameters
    ----------
    vigor_window : float
        Length of the vigor window in seconds (default 0.050, i.e. 50 ms).
    base_gain : float
        Factor converting vigor into velocity (default -12).
    """

    def __init__(self, *args, vigor_window=0.050, base_gain=-12, **kwargs):
        super().__init__(*args, **kwargs)
        self.vigor_window = vigor_window
        # Initial guess of the tracking sampling interval (500 Hz); it is
        # re-estimated from the timestamps on every call to get_velocity.
        self.last_dt = 1 / 500.0
        self.base_gain = base_gain
        self._output_type = namedtuple("s", "vigor")

    def get_velocity(self, lag=0):
        """Return the current vigor-based velocity.

        As side effects, the sampling-interval estimate ``last_dt`` is
        updated and the (unscaled) vigor is appended to the log, at most
        once per tracking timestamp.

        Parameters
        ----------
        lag : float
            How far back in time the vigor window ends, in the same units
            as ``vigor_window`` (Default value = 0).

        Returns
        -------
        float
            ``vigor * base_gain``, where vigor is the NaN-ignoring standard
            deviation of ``tail_sum`` over the window (0 when it cannot be
            computed). Returns 0 when no tracking data is stored yet.
        """
        vigor_n_samples = max(int(round(self.vigor_window / self.last_dt)), 2)
        n_samples_lag = max(int(round(lag / self.last_dt)), 0)
        if not self.acc_tracking.stored_data:
            return 0

        # Take the last (window + lag) samples and keep the oldest `window`
        # of them, so the window ends `lag` before the newest sample.
        past_tail_motion = self.acc_tracking.get_last_n(
            vigor_n_samples + n_samples_lag
        )[0:vigor_n_samples]

        end_t = past_tail_motion.t.iloc[-1]
        start_t = past_tail_motion.t.iloc[0]

        # N samples span N - 1 sampling intervals. Use the number of rows
        # actually returned, which is smaller than requested while the
        # accumulator is still filling up.
        n_intervals = len(past_tail_motion) - 1
        if n_intervals > 0:
            new_dt = (end_t - start_t) / n_intervals
            if new_dt > 0:
                self.last_dt = new_dt

        vigor = np.nanstd(np.array(past_tail_motion.tail_sum))
        if np.isnan(vigor):
            vigor = 0

        # Log only when the tracking timestamp has advanced.
        if len(self.log.times) == 0 or self.log.times[-1] < end_t:
            self.log.update_list(end_t, self._output_type(vigor))

        return vigor * self.base_gain
def rot_mat(theta):
    """Return the standard 2x2 rotation matrix for the angle *theta* (radians)."""
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])
class PositionEstimator(Estimator):
    """Estimate the animal position in projector (screen) coordinates.

    Reads the latest tracked pose (``f0_x``, ``f0_y``, ``f0_theta``) from
    the tracking accumulator and maps it through the experiment's
    camera-to-projector calibration, when one is available.
    """

    def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
        """Uses the camera-to-projector calibration to give the fish position
        in screen coordinates.

        If ``change_thresholds`` is set, the reported position is only
        updated after a big enough change, which prevents small oscillations
        due to tracking noise.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``Estimator`` (tracking accumulator, experiment).
            The calibrator is taken from ``experiment.calibrator``.
        change_thresholds : sequence of 3 floats, optional
            Thresholds, in px and radians, applied element-wise to the
            values returned by ``get_position``, i.e. in (y, x, theta) order.
        velocity_window : int
            Number of most recent tracking samples used by the velocity
            methods.
        """
        super().__init__(*args, **kwargs)
        self.calibrator = self.exp.calibrator
        self.last_location = None
        self.past_values = None
        self.velocity_window = velocity_window
        self.change_thresholds = change_thresholds
        if change_thresholds is not None:
            self.change_thresholds = np.array(change_thresholds)
        self._output_type = namedtuple("f", ["x", "y", "theta"])

    def get_camera_position(self):
        """Return the latest (x, y, theta) in camera coordinates.

        Returns a triple of NaNs when no tracking data is stored yet.
        """
        if len(self.acc_tracking.stored_data) == 0:
            return np.nan, np.nan, np.nan
        # get_last_n returns a DataFrame (see the column-label access in the
        # other methods), so select the row by position and fields by label.
        last = self.acc_tracking.get_last_n(1).iloc[-1]
        return last["f0_x"], last["f0_y"], last["f0_theta"]

    def get_velocity(self):
        """Return the norm of the frame-to-frame displacements in the window.

        The displacement vectors between consecutive samples of the last
        ``velocity_window`` samples are stacked, and the square root of the
        sum of all their squared components is returned (camera pixels).
        """
        positions = self.acc_tracking.get_last_n(self.velocity_window)[
            ["f0_x", "f0_y"]
        ].values
        # Difference along the time axis. Note that np.diff(a, 0), i.e.
        # n=0, would return the positions themselves.
        displacements = np.diff(positions, axis=0)
        return np.sqrt(np.sum(displacements ** 2))

    def get_istantaneous_velocity(self):
        """Return the norm of the tracked velocities over the window.

        Uses the ``f0_vx``/``f0_vy`` columns of the last ``velocity_window``
        samples and returns the square root of the sum of all their squares.
        NOTE(review): despite the name, this aggregates the whole window
        rather than only the latest sample - confirm this is intended.
        """
        vel_xy = self.acc_tracking.get_last_n(self.velocity_window)[
            ["f0_vx", "f0_vy"]
        ].values
        return np.sqrt(np.sum(vel_xy ** 2))

    def reset(self):
        """Reset the log and forget the thresholded position."""
        super().reset()
        self.past_values = None

    def get_position(self):
        """Return the latest pose in projector coordinates.

        Returns
        -------
        numpy.ndarray or namedtuple
            Array ordered (y, x, theta). If there is no data, or the latest
            ``f0_x`` is not finite, a namedtuple (x, y, theta) of NaNs.

        The pose is also appended to the log with fields x, y, theta.
        """
        stored = self.acc_tracking.stored_data
        if len(stored) == 0 or not np.isfinite(stored[-1].f0_x):
            return self._output_type(np.nan, np.nan, np.nan)

        past_coords = stored[-1]
        t = self.acc_tracking.times[-1]

        if self.calibrator.cam_to_proj is not None:
            projmat = np.array(self.calibrator.cam_to_proj)
            if projmat.shape != (2, 3):
                # Unusable calibration: fall back to the identity transform.
                projmat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

            # Affine transform of the position in homogeneous coordinates.
            x, y = projmat @ np.array([past_coords.f0_x, past_coords.f0_y, 1.0])

            # NOTE(review): indexing binds tighter than `@`, so `[::-1]`
            # reverses the (cos, sin) vector *before* the matrix product.
            # This leaves theta unchanged for an identity or uniform-scale
            # calibration. How rotations/reflections are handled depends on
            # the angle convention of f0_theta - confirm.
            theta = np.arctan2(
                *(
                    projmat[:, :2]
                    @ np.array(
                        [np.cos(past_coords.f0_theta), np.sin(past_coords.f0_theta)]
                    )[::-1]
                )
            )
        else:
            x, y, theta = past_coords.f0_x, past_coords.f0_y, past_coords.f0_theta

        # Returned to callers in (y, x, theta) order.
        c_values = np.array((y, x, theta))

        if self.change_thresholds is not None:
            if self.past_values is None:
                self.past_values = np.array(c_values)
            else:
                deltas = c_values - self.past_values
                deltas[2] = reduce_to_pi(deltas[2])
                # Only components that moved beyond their threshold update.
                sel = np.abs(deltas) > self.change_thresholds
                self.past_values[sel] = c_values[sel]
                # Copy so callers cannot mutate the stored reference values.
                c_values = self.past_values.copy()

        # c_values is (y, x, theta); label each logged field explicitly.
        logout = self._output_type(x=c_values[1], y=c_values[0], theta=c_values[2])
        self.log.update_list(t, logout)
        return c_values
# Maps short estimator names to their estimator classes.
estimator_dict = {
    "position": PositionEstimator,
    "vigor": VigorMotionEstimator,
}