# Source code for stytra.tracking.fish

import cv2
import numpy as np
from numba import jit, jitclass, int64, float64

from stytra.tracking.tail import find_fish_midline
from stytra.tracking.preprocessing import BackgroundSubtractor

from itertools import chain

from lightparam import Param
from stytra.tracking.simple_kalman import predict_inplace, update_inplace
from stytra.tracking.pipelines import ImageToDataNode, NodeOutput
from collections import namedtuple


def _fish_column_names(i_fish, n_segments):
    return [
        "f{:d}_x".format(i_fish),
        "f{:d}_vx".format(i_fish),
        "f{:d}_y".format(i_fish),
        "f{:d}_vy".format(i_fish),
        "f{:d}_theta".format(i_fish),
        "f{:d}_vtheta".format(i_fish),
    ] + ["f{:d}_theta_{:02d}".format(i_fish, i) for i in range(n_segments)]


class FishTrackingMethod(ImageToDataNode):
    """Pipeline node that detects and tracks multiple fish in an image.

    Candidate regions are found by thresholding the input image and taking
    connected components of the right area. For every candidate the head
    position, heading and tail-segment angles are estimated, and the result
    is merged into a :class:`Fishes` container, which uses Kalman filtering
    to keep tracks consistent between frames.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, name="fish_tracking", **kwargs)
        self.monitored_headers = ["biggest_area", "f0_theta"]
        self.diagnostic_image_options = [
            "background difference",
            "thresholded background difference",
            "fish detection",
            "thresholded for eye and swim bladder",
        ]
        # structuring element used to dilate the thresholded image
        self.dilation_kernel = np.ones((3, 3), dtype=np.uint8)
        # created lazily by reset(), once the parameters are available
        self.fishes = None

    def changed(self, vals):
        """Reset the tracker when a parameter that changes its layout changed.

        Parameters
        ----------
        vals : dict
            Changed parameter values; a truthy ``"reset"`` entry also
            triggers a reset.
        """
        if any(
            p in vals for p in ["n_segments", "n_fish_max", "bg_downsample"]
        ) or vals.get("reset", False):
            self.reset()

    def reset(self):
        """Rebuild the output tuple type and the container of tracked fish."""
        # the first tail angle is stored as the heading (theta), so only
        # n_segments - 1 additional tail-angle columns are needed
        n_tail_angles = self._params.n_segments - 1
        column_names = list(
            chain.from_iterable(
                _fish_column_names(i_fish, n_tail_angles)
                for i_fish in range(self._params.n_fish_max)
            )
        )
        self._output_type = namedtuple("t", column_names + ["biggest_area"])
        self._output_type_changed = True

        # used for booking a spot for one of the potentially tracked fish
        self.fishes = Fishes(
            self._params.n_fish_max,
            n_segments=n_tail_angles,
            pos_std=self._params.pos_uncertainty,
            pred_coef=self._params.prediction_uncertainty,
            angle_std=np.pi / 10,
            persist_fish_for=self._params.persist_fish_for,
        )

    def _process(
        self,
        bg,
        n_fish_max: Param(1, (1, 50)),
        n_segments: Param(10, (2, 30)),
        bg_downsample: Param(1, (1, 8)),
        bg_dif_threshold: Param(25, (0, 255)),
        threshold_eyes: Param(35, (0, 255)),
        pos_uncertainty: Param(
            1.0,
            (0, 10.0),
            desc="Uncertainty in pixels about the location of the head center of mass",
        ),
        persist_fish_for: Param(
            2,
            (1, 50),
            desc="How many frames does the fish persist for if it is not detected",
        ),
        prediction_uncertainty: Param(0.1, (0.0, 10.0, 0.0001)),
        fish_area: Param((200, 1200), (1, 4000)),
        border_margin: Param(5, (0, 100)),
        tail_length: Param(60.0, (1.0, 200.0)),
        tail_track_window: Param(3, (3, 70)),
    ):
        """Detect fish in one frame and update the tracked fish.

        Parameters
        ----------
        bg : 2D array
            Image to track on; shown as the "background difference"
            diagnostic image, so presumably the background-subtracted
            frame (confirm against the upstream node).
        n_fish_max
            Maximum number of fish tracked (read through ``self._params``
            in :meth:`reset`).
        n_segments
            Number of tail segments to fit.
        bg_downsample
            Factor by which the image is shrunk before thresholding.
        bg_dif_threshold
            Pixels above this value are considered candidate fish pixels.
        threshold_eyes
            Intensity threshold passed to :func:`fish_start` for the head
            position estimate.
        pos_uncertainty, persist_fish_for, prediction_uncertainty
            Tracker settings, read through ``self._params`` in
            :meth:`reset`.
        fish_area
            (min, max) area, in full-resolution pixels, of a connected
            component accepted as a fish (both bounds exclusive).
        border_margin
            Margin added around the bounding box of a candidate.
        tail_length, tail_track_window
            Passed on to ``find_fish_midline``.

        Returns
        -------
        NodeOutput
            Collected messages and an output tuple with the flattened state
            of all fish slots followed by the biggest component area.
        """
        # update the previously-detected fish using the Kalman filter
        if self.fishes is None:
            self.reset()
        else:
            self.fishes.predict()

        area_scale = bg_downsample * bg_downsample
        # NOTE(review): the margin is scaled down here, but below it is
        # combined with coordinates already scaled back to full resolution
        # -- confirm this is intended.
        border_margin = border_margin // bg_downsample

        # downsample background
        if bg_downsample > 1:
            bg_small = cv2.resize(bg, None, fx=1 / bg_downsample, fy=1 / bg_downsample)
        else:
            bg_small = bg

        bg_thresh = cv2.dilate(
            (bg_small > bg_dif_threshold).view(dtype=np.uint8), self.dilation_kernel
        )

        # find regions where there is a difference with the background
        _, _, stats, _ = cv2.connectedComponentsWithStats(bg_thresh)

        # row 0 of stats is the background component, so it is excluded;
        # np.max raises ValueError when no foreground component exists
        try:
            max_area = np.max(stats[1:, cv2.CC_STAT_AREA]) * area_scale
        except ValueError:
            max_area = 0

        # iterate through all the regions different from the background and try
        # to find fish
        messages = []

        nofish = True
        # skip row 0 (the background label), consistently with max_area above
        for row in stats[1:]:
            # check if the contour is fish-sized and central enough
            if not fish_area[0] < row[cv2.CC_STAT_AREA] * area_scale < fish_area[1]:
                continue

            # find the bounding box of the fish in the original image coordinates
            ftop, fleft, fheight, fwidth = (
                int(round(row[x] * bg_downsample))
                for x in [
                    cv2.CC_STAT_TOP,
                    cv2.CC_STAT_LEFT,
                    cv2.CC_STAT_HEIGHT,
                    cv2.CC_STAT_WIDTH,
                ]
            )

            if not (
                (fleft - border_margin >= 0)
                and (fleft + fwidth + border_margin < bg.shape[1])
                and (ftop - border_margin >= 0)
                and (ftop + fheight + border_margin < bg.shape[0])
            ):
                messages.append("W:An object of right area found outside margins")
                continue

            # how much is this region shifted from the upper left corner of the image
            cent_shift = np.array([fleft - border_margin, ftop - border_margin])

            slices = (
                slice(ftop - border_margin, ftop + fheight + border_margin),
                slice(fleft - border_margin, fleft + fwidth + border_margin),
            )

            # take the region around the candidate (copied so the input
            # image is never modified)
            fishdet = bg[slices].copy()

            # estimate the position of the head
            fish_coords = fish_start(fishdet, threshold_eyes)

            # if no actual fish was found here, continue on to the next connected component
            if fish_coords[0] == -1:
                messages.append("W:No appropriate tail start position found")
                continue

            # head position in full-image coordinates
            head_coords_up = fish_coords + cent_shift

            theta = _fish_direction_n(bg, head_coords_up, int(round(tail_length / 2)))

            # find the points of the tail
            points = find_fish_midline(
                bg,
                *head_coords_up,
                theta,
                tail_track_window,
                tail_length / n_segments,
                n_segments + 1,
            )

            # convert to angles, wrapped to [-pi, pi)
            angles = np.mod(points_to_angles(points) + np.pi, np.pi * 2) - np.pi
            if len(angles) == 0:
                messages.append("W:Tail not completely detectable")
                continue

            # also, make the angles continuous (tail angles are stored
            # relative to the first segment)
            angles[1:] = np.unwrap(angles[1:] - angles[0])

            # put the data together for one fish: x, y, heading, tail angles
            fish_coords = np.concatenate([np.array(points[0][:2]), angles])

            nofish = False
            # check if this is a new fish, or it is an update of
            # a fish detected previously
            if self.fishes.update(fish_coords):
                messages.append("I:Updated previous fish")
            elif self.fishes.add_fish(fish_coords):
                messages.append("I:Added new fish")
            else:
                messages.append("E:More fish than n_fish max")

        if nofish:
            messages.append(
                "W:No object of right area, between {:.0f} and {:.0f}".format(
                    *fish_area
                )
            )

        # if a debugging image is to be shown, set it
        if self.set_diagnostic == "background difference":
            self.diagnostic_image = bg
        elif self.set_diagnostic == "thresholded background difference":
            self.diagnostic_image = bg_thresh
        elif self.set_diagnostic == "fish detection":
            fishdet = bg_small.copy()
            fishdet[bg_thresh == 0] = 0
            self.diagnostic_image = fishdet
        elif self.set_diagnostic == "thresholded for eye and swim bladder":
            self.diagnostic_image = np.maximum(bg, threshold_eyes) - threshold_eyes

        # NOTE(review): reset_state is not defined in this class; presumably
        # inherited from the base node -- confirm it exists.
        if self._output_type is None:
            self.reset_state()

        return NodeOutput(
            messages, self._output_type(*self.fishes.coords.flatten(), max_area * 1.0)
        )
spec = [
    ("n_fish", int64),  # number of fish slots
    ("coords", float64[:, :]),  # per-slot state, NaN when the slot is free
    ("i_not_updated", int64[:]),  # predictions since the last update, per slot
    ("F", float64[:, :]),  # state-transition matrix passed to predict_inplace
    ("uncertainties", float64[:]),  # measurement uncertainty for x, y, theta
    ("Q", float64[:, :]),  # process-noise matrix passed to predict_inplace
    ("Ps", float64[:, :, :, :]),  # covariance per slot and per coordinate
    ("def_P", float64[:, :, :]),  # initial covariances for a new fish
    ("persist_fish_for", int64),  # predictions tolerated without an update
]


@jitclass(spec)
class Fishes(object):
    """Fixed-size container of tracked fish with Kalman-filtered state.

    Each row of ``coords`` holds ``x, vx, y, vy, theta, vtheta`` followed
    by the tail angles. A slot is free when its first entry is NaN.
    """

    def __init__(
        self, n_fish_max, pos_std, angle_std, n_segments, pred_coef, persist_fish_for
    ):
        self.n_fish = n_fish_max
        self.coords = np.full((n_fish_max, 6 + n_segments), np.nan)
        self.uncertainties = np.array((pos_std, angle_std, angle_std))
        # default covariance: one diagonal 2x2 matrix per filtered coordinate
        self.def_P = np.zeros((3, 2, 2))
        for i, uc in enumerate(self.uncertainties):
            self.def_P[i, 0, 0] = uc
            self.def_P[i, 1, 1] = uc
        self.i_not_updated = np.zeros(n_fish_max, dtype=np.int64)
        self.Ps = np.zeros((n_fish_max, 3, 2, 2))
        # (value, velocity) model: the value advances by the velocity
        self.F = np.array([[1.0, 1.0], [0.0, 1.0]])
        # NOTE(review): presumably the time step in seconds -- confirm
        dt = 0.02
        self.Q = (
            np.array([[0.25 * dt ** 4, 0.5 * dt ** 3], [0.5 * dt ** 3, dt ** 2]])
            * pred_coef
        )
        self.persist_fish_for = persist_fish_for

    def predict(self):
        """Advance every occupied slot by one step and drop stale fish."""
        for i_fish in range(self.n_fish):
            if not np.isnan(self.coords[i_fish, 0]):
                # x, y and theta are filtered as independent (value, velocity) pairs
                for i_coord in range(0, 6, 2):
                    predict_inplace(
                        self.coords[i_fish, i_coord : i_coord + 2],
                        self.Ps[i_fish, i_coord // 2],
                        self.F,
                        self.Q,
                    )
                self.i_not_updated[i_fish] += 1
                # free the slot once the fish went unseen for too long
                if self.i_not_updated[i_fish] > self.persist_fish_for:
                    self.coords[i_fish, :] = np.nan

    def update(self, new_fish):
        """Merge a measurement into the first matching tracked fish.

        Parameters
        ----------
        new_fish : 1D array
            ``x, y, theta`` followed by the tail angles.

        Returns
        -------
        bool
            True if an existing fish was updated, False otherwise.
        """
        for i_fish in range(self.n_fish):
            if not np.isnan(self.coords[i_fish, 0]):
                # i_not_updated is reset to 0 on update, so a slot already
                # updated since the last predict() is not matched again
                if self.is_close(new_fish, i_fish) and self.i_not_updated[i_fish] != 0:
                    # update position with Kalman filtering
                    for i_coord in range(0, 3):
                        # if it is the angle find the modulo 2pi closest
                        nc = new_fish[i_coord]
                        if i_coord == 2:
                            nc = _minimal_angle_dif(self.coords[i_fish, 4], nc)
                        update_inplace(
                            nc,
                            self.coords[i_fish, i_coord * 2 : i_coord * 2 + 2],
                            self.Ps[i_fish, i_coord],
                            self.uncertainties[i_coord],
                        )
                    # update tail angles
                    self.coords[i_fish, 6:] = new_fish[3:]
                    self.i_not_updated[i_fish] = 0
                    return True
        # explicit result when nothing matched, consistent with add_fish
        return False

    def add_fish(self, new_fish):
        """Store a measurement in the first free slot.

        Returns
        -------
        bool
            True if a free slot was found, False if all slots are occupied.
        """
        for i_fish in range(self.n_fish):
            if np.isnan(self.coords[i_fish, 0]):
                self.coords[i_fish, 0:6:2] = new_fish[:3]
                # a new fish starts with zero velocities
                self.coords[i_fish, 1:6:2] = 0.0
                self.coords[i_fish, 6:] = new_fish[3:]
                self.Ps[i_fish] = self.def_P
                self.i_not_updated[i_fish] = 0
                return True
        return False

    def is_close(self, new_fish, i_fish):
        """ Check whether the new coordinates are
        within a certain number of pixels of the old estimate
        and within a certain angle
        """
        n_px = 15
        d_theta = np.pi / 2
        dists = new_fish[:2] - self.coords[i_fish, 0:4:2]
        # absolute heading difference wrapped to [0, pi]
        dtheta = np.abs(
            np.mod(new_fish[2] - self.coords[i_fish, 4] + np.pi, np.pi * 2) - np.pi
        )
        return np.sum(dists ** 2) < n_px ** 2 and dtheta < d_theta
@jit(nopython=True)
def points_to_angles(points):
    """Convert a sequence of points into the angles of the joining segments.

    Parameters
    ----------
    points
        Sequence of points; only the first two entries (x, y) of each point
        are used.

    Returns
    -------
    1D float64 array
        ``arctan2(dy, dx)`` for each pair of consecutive points; empty when
        fewer than two points are given.
    """
    # clamp so that an empty input gives an empty result instead of a
    # negative array size
    n_angles = max(len(points) - 1, 0)
    angles = np.empty(n_angles, dtype=np.float64)
    for i, (p1, p2) in enumerate(zip(points[0:-1], points[1:])):
        angles[i] = np.arctan2(p2[1] - p1[1], p2[0] - p1[0])
    return angles
@jit(nopython=True)
def fish_start(mask, take_min):
    """Find the weighted centre of the pixels brighter than a threshold.

    Every pixel above ``take_min`` contributes with a weight equal to its
    excess over the threshold.

    Parameters
    ----------
    mask : 2D array
        Image region to search.
    take_min
        Threshold; pixels at or below it are ignored.

    Returns
    -------
    1D float array of length 2
        ``(x, y)`` of the weighted centre (column, row), or ``(-1, -1)``
        when no pixel exceeds the threshold.
    """
    su = 0.0
    ret = np.full((2,), 0.0)
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i, j] > take_min:
                dm = mask[i, j] - take_min
                # ret[0] accumulates the column (x), ret[1] the row (y)
                ret[1] += dm * i
                ret[0] += dm * j
                su += dm
    if su > 0.0:
        return ret / su
    else:
        # sentinel checked by the caller
        ret[:] = -1
        return ret
# Utilities for drawing circles. @jit(nopython=True) def _symmetry_points(x0, y0, x, y): return [ (x0 + x, y0 + y), (x0 - x, y0 + y), (x0 + x, y0 - y), (x0 - x, y0 - y), (x0 + y, y0 + x), (x0 - y, y0 + x), (x0 + y, y0 - x), (x0 - y, y0 - x), ] @jit(nopython=True) def _circle_points(x0, y0, radius): """ Bresenham's circle algorithm Parameters ---------- xc : center x yc : center y r : radius Returns ------- a list of points """ f = 1 - radius ddf_x = 1 ddf_y = -2 * radius x = 0 y = radius points = [ (x0, y0 + radius), (x0, y0 - radius), (x0 + radius, y0), (x0 - radius, y0), ] while x < y: if f >= 0: y -= 1 ddf_y += 2 f += ddf_y x += 1 ddf_x += 2 f += ddf_x points.extend(_symmetry_points(x0, y0, x, y)) return points @jit(nopython=True) def _fish_direction_n(image, start_loc, radius): centre_int = start_loc.astype(np.int16) pixels_rad = _circle_points(centre_int[0], centre_int[1], radius) max_point = pixels_rad[0] max_val = 0 h, w = image.shape for x, y in pixels_rad: if x < 0 or y < 0 or x >= w or y >= h: continue if image[y, x] > max_val: max_val = image[y, x] max_point = (x, y) return np.arctan2(max_point[1] - centre_int[1], max_point[0] - centre_int[0]) @jit(nopython=True) def _minimal_angle_dif(th_old, th_new): return th_old + np.mod(th_new - th_old + np.pi, np.pi * 2) - np.pi