diff --git a/pykick/finders.py b/pykick/finders.py index 36fe4e4..1f79a9d 100644 --- a/pykick/finders.py +++ b/pykick/finders.py @@ -1,3 +1,5 @@ +"""Classes for object detection.""" + from __future__ import division from __future__ import print_function @@ -10,12 +12,34 @@ from .utils import hsv_mask class FieldFinder(object): + """Finds the contour of the field.""" def __init__(self, hsv_lower, hsv_upper): + """ + Parameters + ---------- + hsv_lower, hsv_upper : list + HSV interval of the field in format [H, S, V] + + """ + self.hsv_lower = tuple(hsv_lower) self.hsv_upper = tuple(hsv_upper) def primary_mask(self, frame): + """Apply thresholding to the camera image. + + Parameters + ---------- + frame : array + OpenCV Image in BGR format + + Returns + ------- + array + OpenCV 8-bit 1-channel mask + + """ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV) blurred = cv2.GaussianBlur(hsv, (25, 25), 20) thr = hsv_mask(blurred, self.hsv_lower, self.hsv_upper) @@ -24,6 +48,20 @@ class FieldFinder(object): return thr def find(self, frame): + """Find the contour of the field. + + Parameters + ---------- + frame : array + OpenCV Image in BGR format. + + Returns + ------- + contour or None + OpenCV contour of the field or None if wasn't found. + + """ + thr = self.primary_mask(frame) cnts, _ = cv2.findContours(thr.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -34,12 +72,44 @@ class FieldFinder(object): return field def draw(self, frame, field): + """Draw the contour on the image (demo purposes). + + Parameters + ---------- + frame : array + OpenCV Image in 3-channel format. + field : contour + OpenCV contour of the field as returned by `find`. + + Returns + ------- + array + New image with the field contour drawn. + + """ if field is not None: frame = frame.copy() cv2.drawContours(frame, (field,), -1, (0, 0, 255), 2) return frame def mask_it(self, frame, field, inverse=False): + """Mask out the field or everything else in the image. + + Parameters + ---------- + frame : array + OpenCV Image in 3-channel format. + field : contour + OpenCV contour of the field as returned by `find`. + inverse : bool + If True, mask out the field, if False, everything else. + + Returns + ------- + array + New image with masked out something. + + """ if field is not None: mask = np.zeros(frame.shape[:2], dtype=np.uint8) cv2.drawContours(mask, (field,), -1, 255, -1) @@ -50,12 +120,33 @@ class FieldFinder(object): class GoalFinder(object): + """Find a massive distinctly single-colored goal frame.""" def __init__(self, hsv_lower, hsv_upper): + """ + Parameters + ---------- + hsv_lower, hsv_upper : list + HSV interval of the field in format [H, S, V] + + """ self.hsv_lower = tuple(hsv_lower) self.hsv_upper = tuple(hsv_upper) def primary_mask(self, frame): + """Apply thresholding to the camera image. + + Parameters + ---------- + frame : array + OpenCV Image in BGR format + + Returns + ------- + array + OpenCV 8-bit 1-channel mask + + """ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV) thr = hsv_mask(hsv, self.hsv_lower, self.hsv_upper) thr = cv2.erode(thr, None, iterations=2) @@ -63,6 +154,19 @@ class GoalFinder(object): return thr def goal_similarity(self, contour): + """Calculate the similarity of a contour to the goal. + + Parameters + ---------- + contour : contour + An OpenCV contour. + + Returns + ------- + float + Similarity (or dissimilarity bcs the smaller the more similar). + + """ hull = cv2.convexHull(contour).squeeze() len_h = cv2.arcLength(hull, True) @@ -86,6 +190,19 @@ class GoalFinder(object): return final_score def find(self, frame): + """Find the contour of the goal. + + Parameters + ---------- + frame : array + An OpenCV image in BGR format. + + Returns + ------- + contour or None + An OpenCV contour of the goal or None if wasn't found. + + """ thr = self.primary_mask(frame) cnts, _ = cv2.findContours(thr, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -118,15 +235,32 @@ class GoalFinder(object): return goal def left_right_post(self, contour): + """Return the pixel coordinates of the L-R goalpost.""" return contour[...,0].min(), contour[...,0].max() def goal_center(self, contour): + """Return the center of the goal in pixel coordinates.""" l, r = self.left_right_post(contour) print('Left goal post:', l, 'Right goal post:', r) return (l + r) / 2 def draw(self, frame, goal): + """Draw the contour on the image (demo purposes). + + Parameters + ---------- + frame : array + OpenCV Image in 3-channel format. + field : contour + OpenCV contour of the goal as returned by `find`. + + Returns + ------- + array + New image with the goal contour drawn. + + """ if goal is not None: frame = frame.copy() cv2.drawContours(frame, (goal,), -1, (0, 255, 0), 2) @@ -134,20 +268,55 @@ class GoalFinder(object): class BallFinder(object): + """Class to find the red ball.""" def __init__(self, hsv_lower, hsv_upper, min_radius=0.02): + """ + Parameters + ---------- + hsv_lower, hsv_upper : list + HSV interval of the ball in format [H, S, V]. + min_radius : float + The minimal radius of the ball as fraction of image height. + """ self.hsv_lower = tuple(hsv_lower) self.hsv_upper = tuple(hsv_upper) self.min_radius = min_radius self.history = deque(maxlen=64) def primary_mask(self, frame): + """Apply thresholding to the camera image. + + Parameters + ---------- + frame : array + OpenCV Image in BGR format + + Returns + ------- + array + OpenCV 8-bit 1-channel mask + + """ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV) mask = hsv_mask(hsv, self.hsv_lower, self.hsv_upper) return mask def find(self, frame): + """Find the x, y ball coordinates and the radius. + + Parameters + ---------- + frame : array + An OpenCV image in BGR format. + + Returns + ------- + tuple or None + ((x, y), radius) or None if wasn't found + + """ mask = self.primary_mask(frame) cnts, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -179,6 +348,21 @@ class BallFinder(object): return center, int(radius) def draw(self, frame, ball): + """Draw the contour on the image (demo purposes). + + Parameters + ---------- + frame : array + OpenCV Image in 3-channel format. + field : contour + tuple describing the ball as returned by `find`. + + Returns + ------- + array + New image with the field contour drawn. + + """ if ball is not None: frame = frame.copy() center, radius = ball